"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "424629fea023a83aa84eacf55afc8007314d9f54"
Commit bd2d7898 authored by Maruyama_Aya's avatar Maruyama_Aya Committed by Frank Lee
Browse files

[NFC] polish colossalai/nn/_ops/embedding_bag.py code style (#1552)

parent 73e9eb13
...@@ -90,22 +90,21 @@ def colo_embedding_bag(input_tensor: GeneralTensor, ...@@ -90,22 +90,21 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
# Handle different parallel actions. # Handle different parallel actions.
if not weight.has_compute_spec(): # No Model Parallel Applied if not weight.has_compute_spec(): # No Model Parallel Applied
assert weight.is_replicate(), 'Invalid weight spec for native embedding op' assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
return ColoTensor.from_torch_tensor( return ColoTensor.from_torch_tensor(tensor=F.embedding_bag(input_tensor,
tensor=F.embedding_bag(input_tensor, weight,
weight, offsets=offsets,
offsets=offsets, max_norm=max_norm,
max_norm=max_norm, norm_type=norm_type,
norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq,
scale_grad_by_freq=scale_grad_by_freq, mode=mode,
mode=mode, sparse=sparse,
sparse=sparse, per_sample_weights=per_sample_weights,
per_sample_weights=per_sample_weights, include_last_offset=include_last_offset,
include_last_offset=include_last_offset, padding_idx=padding_idx),
padding_idx=padding_idx), spec=ColoTensorSpec(weight.get_process_group()))
spec=ColoTensorSpec(weight.get_process_group())) elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.is_shard_1dcol(): if weight.is_shard_1dcol():
tp_mode = 'col' tp_mode = 'col'
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment