Unverified Commit bacefdbb authored by Sudhakar Singh's avatar Sudhakar Singh Committed by GitHub
Browse files

Fix failing CI due to PR #557 merge (#616)



fix failing tests due to PR #557
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
parent e4f506a0
......@@ -1225,7 +1225,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
hidden_states_format="sbhd"
attn_input_format="sbhd"
)
.to(dtype=dtype)
.cuda()
......@@ -1248,7 +1248,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
hidden_states_format="bshd"
attn_input_format="bshd"
)
.to(dtype=dtype)
.cuda()
......
......@@ -1034,7 +1034,11 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor, tensor_format: str = "sbhd") -> torch.Tensor:
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
tensor_format: str = "sbhd"
) -> torch.Tensor:
"""
Parameters
----------
......@@ -1056,8 +1060,10 @@ def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor, tensor_format: st
# Only apply the rotary embeddings up to the sequence length of the running
# input.
assert cur_seq_len <= max_seq_len, (f"Rotary Embeddings only supported "
"upto {max_seq_len} sequence length!")
if cur_seq_len > max_seq_len:
raise Exception(f"Rotary Embeddings only supported upto {max_seq_len} "
"sequence length!")
freqs = freqs[:cur_seq_len].to(t.dtype)
if tensor_format == "bshd":
freqs = freqs.transpose(0,1) # [seq, 1, 1, dim] -> [1, seq, 1, dim]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment