Unverified commit 52ee5ea0 authored by Teddy Do, committed by GitHub

Fix bugs in permutation custom partitioning (#2617)



* Use the correct block size for the workspace in row-id-map creation, and shard the workspace's second dim (token blocks) on the same axis as the token dim of routing_map/row_id_map
Signed-off-by: DoubleCheeseCheetos <hanhdp99@gmail.com>

* Reduce the size of the largest test case in the single-GPU scenario so it fits on L40 and A100 GPUs in the CI lineup
Signed-off-by: tdophung <hanhdp99@gmail.com>

---------
Signed-off-by: DoubleCheeseCheetos <hanhdp99@gmail.com>
Signed-off-by: tdophung <hanhdp99@gmail.com>
Co-authored-by: DoubleCheeseCheetos <hanhdp99@gmail.com>
parent c6a92a4d
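
Before the diff, a minimal sketch (not the library code) of the block-size part of the fix: the pass-1 workspace holds one partial-count slot per (expert, token block), so its shape has to follow the block_size the kernel is actually launched with rather than a fixed DEFAULT_BLOCK_SIZE. The helper names below (cdiv, workspace_shape) are illustrative stand-ins.

def cdiv(a: int, b: int) -> int:
    # Ceiling division; equivalent to triton.cdiv(a, b).
    return -(-a // b)


def workspace_shape(num_tokens: int, num_experts: int, block_size: int) -> tuple[int, int]:
    # One counter per (expert, token block). If block_size here disagrees with the
    # kernel's launch grid, the workspace is sized for the wrong number of blocks.
    return (num_experts, cdiv(num_tokens, block_size))


# E.g. 4096 tokens processed in blocks of 256 rows across 64 experts -> 16 blocks per expert.
assert workspace_shape(4096, 64, 256) == (64, 16)
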
@@ -23,7 +23,7 @@ ALL_DISPATCH_COMBINE_CASES = [
     (128, 5, 128, 3),
     (1024, 8, 128, 8),
     (4096, 32, 1280, 2),
-    (4096, 256, 4096, 6),
+    (4096, 64, 4096, 6),
 ]
 DISPATCH_COMBINE_CASES = {
     "L0": ALL_DISPATCH_COMBINE_CASES[0:2],
@@ -44,7 +44,7 @@ ALL_DISPATCH_COMBINE_PADDING_CASES = [
     (128, 5, 128, 3, 8),
     (1024, 8, 128, 8, 16),
     (4096, 32, 1280, 2, 128),
-    (4096, 256, 4096, 6, 16),
+    (4096, 64, 4096, 6, 16),
 ]
 DISPATCH_COMBINE_PADDING_CASES = {
     "L0": ALL_DISPATCH_COMBINE_PADDING_CASES[0:2],
...
@@ -65,8 +65,6 @@ class RowIdMapPass1Primitive(BasePrimitive):
     @staticmethod
     def abstract(routing_map_aval, *, num_tokens, num_experts, block_size):
         """Shape/dtype inference for pass 1."""
-        del block_size  # Only affects grid, not output shape
         assert routing_map_aval.shape == (
             num_tokens,
             num_experts,
@@ -75,7 +73,7 @@ class RowIdMapPass1Primitive(BasePrimitive):
         row_id_map_shape = (num_tokens, num_experts * 2 + 1)
         workspace_shape = (
             num_experts,
-            triton.cdiv(num_tokens, DEFAULT_BLOCK_SIZE),
+            triton.cdiv(num_tokens, block_size),
         )
         return (
@@ -134,9 +132,10 @@ class RowIdMapPass1Primitive(BasePrimitive):
             desc="RowIdMapPass1.row_id_map_sharding",
         )
         # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE))
+        # Second dim depends on num_tokens, so it must be sharded on the same axis as tokens
         workspace_sharding = NamedSharding(
             mesh,
-            PartitionSpec(None, None),
+            PartitionSpec(None, routing_map_spec[0]),
             desc="RowIdMapPass1.workspace_sharding",
         )
         return [row_id_map_sharding, workspace_sharding]
@@ -156,9 +155,11 @@ class RowIdMapPass1Primitive(BasePrimitive):
             PartitionSpec(routing_map_spec[0], None),
             desc="RowIdMapPass1.row_id_map_sharding",
         )
+        # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE))
+        # Second dim depends on num_tokens, so it must be sharded on the same axis as tokens
         workspace_sharding = NamedSharding(
             mesh,
-            PartitionSpec(None, None),
+            PartitionSpec(None, routing_map_spec[0]),
             desc="RowIdMapPass1.workspace_sharding",
         )
         out_shardings = [row_id_map_sharding, workspace_sharding]
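
As a hedged illustration of the sharding change in the two hunks above (plain jax.sharding APIs; the "tp" axis name and the 1-D mesh are assumptions for the example, and the patch's NamedSharding wrapper with its desc argument is replaced by the stock class): once the token dim of routing_map is sharded over a mesh axis, the workspace's second dim, which indexes token blocks, must live on the same axis so each device keeps the partial counts for its own tokens.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical 1-D mesh over all local devices; "tp" is just an example axis name.
mesh = Mesh(np.array(jax.devices()), ("tp",))

# routing_map: (num_tokens, num_experts), tokens sharded over "tp".
routing_map_spec = PartitionSpec("tp", None)

# workspace: (num_experts, cdiv(num_tokens, block_size)).
# Before the fix: PartitionSpec(None, None) -> replicated, out of step with the tokens.
# After the fix: the block dim reuses the token axis.
workspace_sharding = NamedSharding(mesh, PartitionSpec(None, routing_map_spec[0]))
print(workspace_sharding)  # the workspace's block dim now follows the token axis
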
@@ -186,7 +187,8 @@ class RowIdMapPass1Primitive(BasePrimitive):
         # Note: row_id_cols != experts since it's num_experts * 2 + 1
         row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_row_id_cols")
         # workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE))
-        workspace_spec = (f"{prefix}_experts", f"{prefix}_ws_blocks")
+        # Second dim depends on num_tokens, so use same factor to ensure same sharding
+        workspace_spec = (f"{prefix}_experts", f"{prefix}_tokens")
         return SdyShardingRule((input_spec,), (row_id_map_spec, workspace_spec))
@@ -208,10 +210,9 @@ class RowIdMapPass2Primitive(BasePrimitive):
     def abstract(row_id_map_aval, workspace_aval, *, num_tokens, num_experts, block_size):
         """Shape/dtype inference for pass 2 (in-place operation)."""
         del row_id_map_aval, workspace_aval
-        del block_size
         row_id_map_shape = (num_tokens, num_experts * 2 + 1)
-        workspace_shape = (num_experts, triton.cdiv(num_tokens, DEFAULT_BLOCK_SIZE))
+        workspace_shape = (num_experts, triton.cdiv(num_tokens, block_size))
         return (
             jax.core.ShapedArray(row_id_map_shape, jnp.int32),
@@ -270,9 +271,11 @@ class RowIdMapPass2Primitive(BasePrimitive):
             PartitionSpec(*row_id_map_spec),
             desc="RowIdMapPass2.row_id_map_sharding",
         )
+        # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE))
+        # Second dim depends on num_tokens, so it must be sharded on the same axis as tokens
         workspace_sharding = NamedSharding(
             mesh,
-            PartitionSpec(None, None),
+            PartitionSpec(None, row_id_map_spec[0]),
             desc="RowIdMapPass2.workspace_sharding",
         )
         return [row_id_map_sharding, workspace_sharding]
@@ -292,9 +295,11 @@ class RowIdMapPass2Primitive(BasePrimitive):
             PartitionSpec(*row_id_map_spec),
             desc="RowIdMapPass2.row_id_map_sharding",
         )
+        # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE))
+        # Second dim depends on num_tokens, so it must be sharded on the same axis as tokens
         workspace_sharding = NamedSharding(
             mesh,
-            PartitionSpec(None, None),
+            PartitionSpec(None, row_id_map_spec[0]),
             desc="RowIdMapPass2.workspace_sharding",
         )
         out_shardings = [row_id_map_sharding, workspace_sharding]
@@ -317,7 +322,9 @@ class RowIdMapPass2Primitive(BasePrimitive):
         del num_tokens, num_experts, block_size, mesh, value_types, result_types
         prefix = "RowIdMapPass2"
         row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_cols")
-        workspace_spec = (f"{prefix}_ws_experts", f"{prefix}_ws_blocks")
+        # workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE))
+        # Second dim depends on num_tokens, so use same factor to ensure same sharding
+        workspace_spec = (f"{prefix}_ws_experts", f"{prefix}_tokens")
         return SdyShardingRule((row_id_map_spec, workspace_spec), (row_id_map_spec, workspace_spec))
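
A small sketch of the Shardy-rule side of the same fix: dimensions that carry the same factor name in an SdyShardingRule are constrained to the same sharding, so reusing the tokens factor for the workspace's second dim ties it to the token dim of row_id_map. The tuples below mirror the ones in the patch; building and registering the actual rule is omitted.

prefix = "RowIdMapPass2"

# row_id_map: (num_tokens, num_experts * 2 + 1)
row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_cols")

# Before: an independent "ws_blocks" factor, so the block dim could stay unsharded.
old_workspace_spec = (f"{prefix}_ws_experts", f"{prefix}_ws_blocks")

# After: the block dim shares the "tokens" factor, so it is sharded together with the tokens.
new_workspace_spec = (f"{prefix}_ws_experts", f"{prefix}_tokens")

# These specs are what the patch passes to SdyShardingRule(operand_specs, result_specs).
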
...