"vscode:/vscode.git/clone" did not exist on "cbce23fb86705aff7f9f1294a295027d7ed674ef"
Unverified Commit 21162d06 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

layer number bug fixes (#326)



* layer number bug fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 706d7da3
...@@ -977,7 +977,7 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -977,7 +977,7 @@ class MultiHeadAttention(torch.nn.Module):
bias: bool = True, bias: bool = True,
) -> None: ) -> None:
super().__init__() super().__init__()
self.layer_number = (layer_number,) self.layer_number = layer_number
self.input_layernorm = input_layernorm self.input_layernorm = input_layernorm
self.attention_type = attention_type self.attention_type = attention_type
self.get_rng_state_tracker = get_rng_state_tracker self.get_rng_state_tracker = get_rng_state_tracker
...@@ -991,9 +991,9 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -991,9 +991,9 @@ class MultiHeadAttention(torch.nn.Module):
qkv_weight_interleaved = False qkv_weight_interleaved = False
self.qkv_weight_interleaved = qkv_weight_interleaved self.qkv_weight_interleaved = qkv_weight_interleaved
assert ( assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
attention_type in AttnTypes if layer_number is not None:
), f"attention_type {attention_type} not supported" assert layer_number > 0, "layer_number must be a positive integer"
tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group) tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
self.tp_size = tp_size self.tp_size = tp_size
...@@ -1090,7 +1090,7 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -1090,7 +1090,7 @@ class MultiHeadAttention(torch.nn.Module):
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
sequence_parallel=sequence_parallel, sequence_parallel=sequence_parallel,
tp_group=tp_group, tp_group=tp_group,
layer_number=layer_number, layer_number=self.layer_number,
) )
# Linear # Linear
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment