Unverified commit 8833a8d0, authored by Charlene Yang and committed by GitHub
Browse files

[PyTorch] Reduce the amount of roundup for max_seqlen in THD (#1079)



Reduce the round-up of max_seqlen for THD: round up to the next multiple of 64 instead of the next power of two.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 121ff62a
...@@ -5725,13 +5725,13 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -5725,13 +5725,13 @@ class DotProductAttention(TransformerEngineBaseModule):
seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1] seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1]
else: else:
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
max_seqlen_q = pow(2, math.ceil(math.log2(seqlens_q.max().item()))) max_seqlen_q = int((seqlens_q.max().item() + 63) // 64 * 64)
if max_seqlen_kv is None: if max_seqlen_kv is None:
if cu_seqlens_kv_padded is not None: if cu_seqlens_kv_padded is not None:
seqlens_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1] seqlens_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1]
else: else:
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item()))) max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64)
batch_size = len(cu_seqlens_q) - 1 batch_size = len(cu_seqlens_q) - 1
cp_size = 1 if self.cp_group is None else get_distributed_world_size(self.cp_group) cp_size = 1 if self.cp_group is None else get_distributed_world_size(self.cp_group)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment