"vscode:/vscode.git/clone" did not exist on "655a5e48df3937bf793add53aa95ce0c992a24c6"
Unverified Commit 94a426b0 authored by Charlene Yang, committed by GitHub
Browse files

[PyTorch] Fix tp_group_initialized error (#939)



fix tp_initialized error
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 29e8bfc9
...@@ -4504,8 +4504,13 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -4504,8 +4504,13 @@ class DotProductAttention(TransformerEngineBaseModule):
self.attn_mask_type = attn_mask_type self.attn_mask_type = attn_mask_type
self.window_size = window_size self.window_size = window_size
self.window_size = check_set_window_size(attn_mask_type, self.window_size) self.window_size = check_set_window_size(attn_mask_type, self.window_size)
self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group) if tp_group is None:
self.tp_group = tp_group self.tp_size = tp_size
if tp_size == 1:
self.set_tensor_parallel_group(tp_group)
else:
self.tp_size = get_distributed_world_size(tp_group)
self.set_tensor_parallel_group(tp_group)
self.get_rng_state_tracker = get_rng_state_tracker self.get_rng_state_tracker = get_rng_state_tracker
self.num_attention_heads = num_attention_heads self.num_attention_heads = num_attention_heads
self.layer_number = 1 if layer_number is None else layer_number self.layer_number = 1 if layer_number is None else layer_number
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment