Commit 3138cbdd authored by wooway777

issue/824 - removed redundant batch handling

parent 215d1932
......@@ -20,16 +20,6 @@ def rope(
) -> Tensor:
r"""Rotary Position Embedding(RoPE)."""
bs, seq_len, num_heads, head_dim = x.shape
x_stride = x.stride()
assert seq_len * x_stride[1] == x_stride[0], (
"x need to be continuous in dim=0 and dim=1"
)
x = x.view((bs * seq_len, num_heads, head_dim))
bs, num = pos_ids.shape
pos_ids = pos_ids.view((bs * num,))
if out is None:
return Tensor(
_infinicore.rope(
......@@ -39,9 +29,8 @@ def rope(
cos_table._underlying,
algo,
)
).view((bs, seq_len, num_heads, head_dim))
)
out = out.view((bs * seq_len, num_heads, head_dim))
_infinicore.rope_(
out._underlying,
x._underlying,
......@@ -50,4 +39,4 @@ def rope(
cos_table._underlying,
algo,
)
return out.view((bs, seq_len, num_heads, head_dim))
return out
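
For context on what the kernel computes (not part of this diff): RoPE rotates pairs of feature dimensions by a position-dependent angle, and the algo flag selects how dimensions are paired. Below is a minimal PyTorch sketch under stated assumptions; the function name rope_reference, the default base of 10000.0, and the on-the-fly angle computation are illustrative, since the wrapper above instead consumes precomputed sin_table/cos_table.

import torch

def rope_reference(x, pos_ids, base=10000.0, algo="gpt_neox"):
    # x: (bs, seq_len, num_heads, head_dim); pos_ids: (bs, seq_len)
    head_dim = x.shape[-1]
    assert head_dim % 2 == 0, "head_dim must be even"
    inv_freq = 1.0 / base ** (
        torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    )
    angles = pos_ids[..., None].float() * inv_freq  # (bs, seq_len, head_dim // 2)
    sin = angles.sin()[:, :, None, :]               # broadcast over heads
    cos = angles.cos()[:, :, None, :]
    xf = x.float()
    if algo == "gpt_j":
        # GPT-J pairs interleaved dims (2i, 2i + 1).
        even, odd = xf[..., 0::2], xf[..., 1::2]
        out = torch.empty_like(xf)
        out[..., 0::2] = even * cos - odd * sin
        out[..., 1::2] = even * sin + odd * cos
    else:
        # GPT-NeoX pairs dim i with dim i + head_dim // 2.
        half = head_dim // 2
        x1, x2 = xf[..., :half], xf[..., half:]
        out = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return out.to(x.dtype)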
......@@ -22,11 +22,85 @@ import infinicore
_TEST_CASES_DATA = [
# bs, seq_len, num, head_dim, Algo
(1, 1, 1, 64, RopeAlgo.GPT_NEOX),
(1, 5, 32, 64, RopeAlgo.GPT_NEOX),
(1, 1, 1, 128, RopeAlgo.GPT_J),
(1, 10, 1, 64, RopeAlgo.GPT_J),
# bs, seq_len, num, head_dim, src strides, dst strides, Algo
(1, 1, 1, 64, None, None, RopeAlgo.GPT_NEOX),
(1, 5, 32, 64, None, None, RopeAlgo.GPT_NEOX),
(1, 1, 1, 128, None, None, RopeAlgo.GPT_J),
(1, 10, 1, 64, None, None, RopeAlgo.GPT_J),
(2, 20, 16, 128, None, None, RopeAlgo.GPT_NEOX),
(4, 50, 32, 256, None, None, RopeAlgo.GPT_J),
(
2,
20,
16,
128,
(20 * 16 * 128 * 16, 16 * 128 * 4, 128 * 2, 1),
(20 * 16 * 128 * 16, 16 * 128 * 4, 128 * 2, 1),
RopeAlgo.GPT_NEOX,
),
(
2,
20,
16,
128,
(20 * 16 * 128 * 16, 16 * 128 * 4, 128 * 2, 1),
(20 * 16 * 128 * 16, 16 * 128 * 4, 128 * 2, 1),
RopeAlgo.GPT_J,
),
(
4,
50,
32,
256,
(50 * 32 * 256 * 16, 32 * 256 * 4, 256 * 2, 1),
(50 * 32 * 256 * 36, 32 * 256 * 6, 256 * 3, 1),
RopeAlgo.GPT_NEOX,
),
(
4,
50,
32,
256,
(50 * 32 * 256 * 16, 32 * 256 * 4, 256 * 2, 1),
(50 * 32 * 256 * 36, 32 * 256 * 6, 256 * 3, 1),
RopeAlgo.GPT_J,
),
(
32,
64,
8,
128,
(64 * 8 * 128 * 16, 8 * 128 * 4, 128 * 2, 1),
(64 * 8 * 128 * 16, 8 * 128 * 4, 128 * 2, 1),
RopeAlgo.GPT_NEOX,
),
(
32,
64,
8,
128,
(64 * 8 * 128 * 16, 8 * 128 * 4, 128 * 2, 1),
(64 * 8 * 128 * 16, 8 * 128 * 4, 128 * 2, 1),
RopeAlgo.GPT_J,
),
(
64,
128,
32,
64,
(128 * 32 * 64 * 16, 32 * 64 * 4, 64 * 2, 1),
(128 * 32 * 64 * 36, 32 * 64 * 6, 64 * 3, 1),
RopeAlgo.GPT_NEOX,
),
(
64,
128,
32,
64,
(128 * 32 * 64 * 16, 32 * 64 * 4, 64 * 2, 1),
(128 * 32 * 64 * 36, 32 * 64 * 6, 64 * 3, 1),
RopeAlgo.GPT_J,
),
]
# Tolerance configuration
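
The stride tuples in the new cases describe padded, non-contiguous layouts: each stride is larger than the contiguous one for its dimension, so the tensor is a view into an oversized buffer. A minimal sketch of how such a tensor can be materialized (make_strided is a hypothetical helper, not part of the test suite, and the real framework may allocate differently):

import torch

def make_strided(shape, strides, dtype=torch.float32):
    if strides is None:
        return torch.randn(shape, dtype=dtype)  # contiguous case
    # The backing buffer must cover the largest reachable offset.
    numel = 1 + sum((dim - 1) * st for dim, st in zip(shape, strides))
    buf = torch.randn(numel, dtype=dtype)
    return buf.as_strided(shape, strides)

x = make_strided((2, 20, 16, 128), (20 * 16 * 128 * 16, 16 * 128 * 4, 128 * 2, 1))
print(x.is_contiguous())  # False: every dimension is padded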
......@@ -49,7 +123,8 @@ def parse_test_cases():
for data in _TEST_CASES_DATA:
bs, seq_len, num, head_dim = data[0], data[1], data[2], data[3]
algo = data[4]
src_strides, dst_strides = data[4], data[5]
algo = data[6]
# Determine shapes based on batch dimension
out_shape = (bs, seq_len, num, head_dim)
......@@ -58,15 +133,16 @@ def parse_test_cases():
cos_table_shape = (seq_len, head_dim // 2)
# Check if tensors support in-place operations
c_supports_inplace = not is_broadcast(out_shape)
# x tensor supports in-place if it's not a broadcasted tensor
x_supports_inplace = not is_broadcast(src_strides)
# Generate test cases for all data types
for dtype in _TENSOR_DTYPES:
tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})
# Create typed tensor specs
out_spec = TensorSpec.from_tensor(out_shape, None, dtype)
x_spec = TensorSpec.from_tensor(x_shape, None, dtype)
out_spec = TensorSpec.from_tensor(out_shape, dst_strides, dtype)
x_spec = TensorSpec.from_tensor(x_shape, src_strides, dtype)
sin_table_spec = TensorSpec.from_tensor(sin_table_shape, None, dtype)
cos_table_spec = TensorSpec.from_tensor(cos_table_shape, None, dtype)
......@@ -83,7 +159,7 @@ def parse_test_cases():
)
# Test Case 2: In-place with explicit output tensor
if c_supports_inplace:
if dst_strides is None or not is_broadcast(dst_strides):
test_cases.append(
TestCase(
inputs=[x_spec, sin_table_spec, cos_table_spec],
......@@ -95,6 +171,19 @@ def parse_test_cases():
)
)
# Test Case 3: In-place on input tensor (x)
if x_supports_inplace:
test_cases.append(
TestCase(
inputs=[x_spec, sin_table_spec, cos_table_spec],
kwargs={"algo": algo, "out": 0}, # Use index 0 for first input
output_spec=None,
comparison_target=0, # Compare first input (x tensor)
tolerance=tolerance,
description=f"Rope - INPLACE(x)",
)
)
return test_cases
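
The in-place decisions above hinge on is_broadcast: a tensor whose strides contain a zero maps several logical elements onto one storage location, so writing results back into it would clobber shared data. A plausible sketch of the stride-based check (an assumption; the helper's real implementation may differ, and the removed code also accepted a shape):

def is_broadcast(strides) -> bool:
    # None means default contiguous strides, which are never broadcast views.
    return strides is not None and any(s == 0 for s in strides)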
......@@ -107,15 +196,22 @@ def rotary_embedding(t, sin, cos, algo, *, out=None):
return t_out_1, t_out_2
ans = t.clone()
# If out parameter is provided and it's the same as input t, operate in-place
if out is not None:
if out.data_ptr() == t.data_ptr():
ans = t # Use the same tensor for in-place operation
else:
ans = out # Use provided output tensor
else:
ans = t.clone()
dh = t.shape[-1]
dt = t.dtype
assert dh % 2 == 0, "Embedding dimension must be even."
if RopeAlgo.GPT_J == algo:
t_even = t[..., 0::2] # [seq_len, n_head, dh // 2]
t_odd = t[..., 1::2] # [seq_len, n_head, dh // 2]
t_even = t[..., 0::2] # [bs, seq_len, n_head, dh // 2]
t_odd = t[..., 1::2] # [bs, seq_len, n_head, dh // 2]
t_out_even, t_out_odd = _torch_rope(sin, cos, t_even, t_odd)
......@@ -131,9 +227,10 @@ def rotary_embedding(t, sin, cos, algo, *, out=None):
ans[..., :half_dim] = t_out_first.to(dt)
ans[..., half_dim:] = t_out_second.to(dt)
else:
raise KeyError("error Algo ")
raise KeyError("Unsupported RoPE algorithm")
if out is not None:
# If operating in-place on t, we don't need to copy back
if out is not None and out.data_ptr() != t.data_ptr():
out.copy_(ans)
return out
return ans
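
For each selected pair, _torch_rope applies the standard 2-D rotation at token position m: GPT-J rotates the interleaved pair (t_2i, t_2i+1), while GPT-NeoX rotates (t_i, t_{i + d/2}) for head dimension d. In LaTeX, with the conventional base 10000 that the precomputed tables are assumed to encode:

\begin{pmatrix} t'_a \\ t'_b \end{pmatrix}
=
\begin{pmatrix}
\cos(m\theta_i) & -\sin(m\theta_i) \\
\sin(m\theta_i) & \phantom{-}\cos(m\theta_i)
\end{pmatrix}
\begin{pmatrix} t_a \\ t_b \end{pmatrix},
\qquad
\theta_i = 10000^{-2i/d}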
......