Commit 0c4c9aa6 authored by BigOneLiXiaoMing's avatar BigOneLiXiaoMing Committed by Frank Lee
Browse files

[NFC] polish colossalai/nn/_ops/embedding.py code style (#1561)

parent 08815f0e
...@@ -113,8 +113,7 @@ def colo_embedding(input_tensor: GeneralTensor, ...@@ -113,8 +113,7 @@ def colo_embedding(input_tensor: GeneralTensor,
if not weight.has_compute_spec(): # No Model Parallel Applied if not weight.has_compute_spec(): # No Model Parallel Applied
assert weight.is_replicate(), 'Invalid weight spec for native embedding op' assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
return ColoTensor.from_torch_tensor( return ColoTensor.from_torch_tensor(tensor=F.embedding(input_tensor,
tensor=F.embedding(input_tensor,
weight, weight,
padding_idx=padding_idx, padding_idx=padding_idx,
max_norm=max_norm, max_norm=max_norm,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment