Unverified Commit dc0f7ccf authored by wchen61's avatar wchen61 Committed by GitHub
Browse files

[BugFix] Enhance test_pos_encoding to support execution on multi-devices (#13187)


Signed-off-by: default avatarwchen61 <wchen61@foxmail.com>
parent d3d547e0
...@@ -70,7 +70,7 @@ def test_rotary_embedding( ...@@ -70,7 +70,7 @@ def test_rotary_embedding(
if rotary_dim is None: if rotary_dim is None:
rotary_dim = head_size rotary_dim = head_size
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style) rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
rope = rope.to(dtype=dtype) rope = rope.to(dtype=dtype, device=torch.get_default_device())
positions = torch.randint(0, max_position, (batch_size, seq_len)) positions = torch.randint(0, max_position, (batch_size, seq_len))
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size) query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
...@@ -125,7 +125,7 @@ def test_batched_rotary_embedding( ...@@ -125,7 +125,7 @@ def test_batched_rotary_embedding(
"rope_type": "linear", "rope_type": "linear",
"factor": (1, ) "factor": (1, )
}) })
rope = rope.to(dtype=dtype) rope = rope.to(dtype=dtype, device=torch.get_default_device())
positions = torch.randint(0, max_position, (batch_size, seq_len)) positions = torch.randint(0, max_position, (batch_size, seq_len))
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size) query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
...@@ -184,7 +184,7 @@ def test_batched_rotary_embedding_multi_lora( ...@@ -184,7 +184,7 @@ def test_batched_rotary_embedding_multi_lora(
"rope_type": "linear", "rope_type": "linear",
"factor": tuple(scaling_factors) "factor": tuple(scaling_factors)
}) })
rope = rope.to(dtype=dtype) rope = rope.to(dtype=dtype, device=torch.get_default_device())
positions = torch.randint(0, max_position, (batch_size, seq_len)) positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size, query = torch.randn(batch_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment