Unverified commit f62cad90, authored by Michael Goldfarb, committed by GitHub
Browse files

Fix sharding of segment position to match id in ring attention. (#2349)

parent 4ff3eed1
...@@ -1784,6 +1784,9 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1784,6 +1784,9 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
) )
arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding arg_shardings[4] = seed_sharding
# Ensure segment_pos gets same sharding as ID.
arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings) arg_shardings = tuple(arg_shardings)
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
...@@ -1991,7 +1994,13 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1991,7 +1994,13 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
# Ensure segment_pos gets same sharding as ID.
arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings)
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
helper = _FusedAttnCPWithP2PHelper(mesh, config) helper = _FusedAttnCPWithP2PHelper(mesh, config)
...@@ -2265,6 +2274,9 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -2265,6 +2274,9 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
) )
arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding arg_shardings[4] = seed_sharding
# Ensure segment_pos gets same sharding as ID.
arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings) arg_shardings = tuple(arg_shardings)
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
...@@ -2403,7 +2415,11 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2403,7 +2415,11 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
if not is_context_parallel: if not is_context_parallel:
return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos) return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
arg_shardings = tuple(arg.sharding for arg in arg_infos) arg_shardings = [arg_i.sharding for arg_i in arg_infos]
# Ensure segment_pos gets same sharding as ID.
arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings)
# dq, dk, dv, dbias sharding = q, k, v, bias sharding # dq, dk, dv, dbias sharding = q, k, v, bias sharding
out_shardings = tuple(arg.sharding for arg in arg_infos[:4]) out_shardings = tuple(arg.sharding for arg in arg_infos[:4])
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment