Unverified Commit bacefdbb authored by Sudhakar Singh's avatar Sudhakar Singh Committed by GitHub
Browse files

Fix failing CI due to PR #557 merge (#616)



fix failing tests due to PR #557
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
parent e4f506a0
...@@ -1225,7 +1225,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): ...@@ -1225,7 +1225,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
kv_channels=config.embed, kv_channels=config.embed,
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
hidden_states_format="sbhd" attn_input_format="sbhd"
) )
.to(dtype=dtype) .to(dtype=dtype)
.cuda() .cuda()
...@@ -1248,7 +1248,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): ...@@ -1248,7 +1248,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
kv_channels=config.embed, kv_channels=config.embed,
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
hidden_states_format="bshd" attn_input_format="bshd"
) )
.to(dtype=dtype) .to(dtype=dtype)
.cuda() .cuda()
......
...@@ -1034,7 +1034,11 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor: ...@@ -1034,7 +1034,11 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
return torch.cat((-x2, x1), dim=-1) return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor, tensor_format: str = "sbhd") -> torch.Tensor: def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
tensor_format: str = "sbhd"
) -> torch.Tensor:
""" """
Parameters Parameters
---------- ----------
...@@ -1056,8 +1060,10 @@ def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor, tensor_format: st ...@@ -1056,8 +1060,10 @@ def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor, tensor_format: st
# Only apply the rotary embeddings up to the sequence length of the running # Only apply the rotary embeddings up to the sequence length of the running
# input. # input.
assert cur_seq_len <= max_seq_len, (f"Rotary Embeddings only supported " if cur_seq_len > max_seq_len:
"upto {max_seq_len} sequence length!") raise Exception(f"Rotary Embeddings only supported upto {max_seq_len} "
"sequence length!")
freqs = freqs[:cur_seq_len].to(t.dtype) freqs = freqs[:cur_seq_len].to(t.dtype)
if tensor_format == "bshd": if tensor_format == "bshd":
freqs = freqs.transpose(0,1) # [seq, 1, 1, dim] -> [1, seq, 1, dim] freqs = freqs.transpose(0,1) # [seq, 1, 1, dim] -> [1, seq, 1, dim]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment