Unverified Commit 21162d06 authored by Kirthi Shankar Sivamani, committed by GitHub

layer number bug fixes (#326)

* layer number bug fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 706d7da3
@@ -977,7 +977,7 @@ class MultiHeadAttention(torch.nn.Module):
         bias: bool = True,
     ) -> None:
         super().__init__()
-        self.layer_number = (layer_number,)
+        self.layer_number = layer_number
         self.input_layernorm = input_layernorm
         self.attention_type = attention_type
         self.get_rng_state_tracker = get_rng_state_tracker
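
The actual bug fixed in this hunk is a stray trailing comma: in Python, (layer_number,) is a one-element tuple, not a parenthesized integer, so self.layer_number was silently stored as a tuple. A minimal standalone sketch of the pitfall (hypothetical values, not TransformerEngine code):

layer_number = 3

wrong = (layer_number,)   # trailing comma makes a one-element tuple: (3,)
right = layer_number      # plain int: 3

print(type(wrong))        # <class 'tuple'>
print(type(right))        # <class 'int'>

# The tuple only blows up at use time, far from the assignment, e.g. when
# the layer number feeds into a per-layer scaling factor:
try:
    scale = 1.0 / wrong
except TypeError as exc:
    print(exc)            # unsupported operand type(s) for /: 'float' and 'tuple'

scale = 1.0 / right       # works: 0.333...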
@@ -991,9 +991,9 @@ class MultiHeadAttention(torch.nn.Module):
             qkv_weight_interleaved = False
         self.qkv_weight_interleaved = qkv_weight_interleaved
-        assert (
-            attention_type in AttnTypes
-        ), f"attention_type {attention_type} not supported"
+        assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
+        if layer_number is not None:
+            assert layer_number > 0, "layer_number must be a positive integer"
         tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
         self.tp_size = tp_size
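
Besides collapsing the multi-line assert, this hunk adds a constructor-time guard so an invalid layer number fails immediately rather than deep inside a forward pass. A rough standalone sketch of the check it introduces (assuming, per the error message, that layer numbering is 1-based):

def check_layer_number(layer_number):
    # Mirrors the guard added above: None is allowed (the layer number is
    # optional), otherwise it must be a positive integer.
    if layer_number is not None:
        assert layer_number > 0, "layer_number must be a positive integer"
    return layer_number

check_layer_number(None)   # passes: layer number is optional
check_layer_number(1)      # passes: first layer
# check_layer_number(0)    # would raise AssertionError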
@@ -1090,7 +1090,7 @@ class MultiHeadAttention(torch.nn.Module):
             attn_mask_type=attn_mask_type,
             sequence_parallel=sequence_parallel,
             tp_group=tp_group,
-            layer_number=layer_number,
+            layer_number=self.layer_number,
         )
         # Linear
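
The last hunk passes self.layer_number (the attribute fixed in the first hunk) instead of the raw constructor argument, keeping the child attention module in sync with the stored value should any normalization ever happen at assignment time.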