Unverified Commit 41fb9bcf authored by Charlene Yang's avatar Charlene Yang Committed by GitHub
Browse files

[PyTorch] fix `test_current_device` test (#2398)



* fix test_current_device
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 05bfa3f8
...@@ -55,7 +55,29 @@ def test_current_device(model, module): ...@@ -55,7 +55,29 @@ def test_current_device(model, module):
self_attn_mask_type="padding", self_attn_mask_type="padding",
device=f"cuda:{tensor_device}", device=f"cuda:{tensor_device}",
) )
num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item() seqlens_q = torch.randint(
1,
config.max_seqlen_q,
[config.batch_size],
dtype=torch.int32,
device=f"cuda:{tensor_device}",
)
cu_seqlens_q = torch.zeros(
config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
)
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
seqlens_kv = torch.randint(
1,
config.max_seqlen_kv,
[config.batch_size],
dtype=torch.int32,
device=f"cuda:{tensor_device}",
)
cu_seqlens_kv = torch.zeros(
config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
num_tokens = cu_seqlens_q[-1]
args = [ args = [
torch.randn( torch.randn(
(num_tokens, config.hidden_size), (num_tokens, config.hidden_size),
...@@ -64,9 +86,6 @@ def test_current_device(model, module): ...@@ -64,9 +86,6 @@ def test_current_device(model, module):
requires_grad=True, requires_grad=True,
) )
] ]
cu_seqlens_q, cu_seqlens_kv = [
torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
]
kwargs["cu_seqlens_q"] = cu_seqlens_q kwargs["cu_seqlens_q"] = cu_seqlens_q
kwargs["cu_seqlens_kv"] = cu_seqlens_kv kwargs["cu_seqlens_kv"] = cu_seqlens_kv
kwargs["max_seqlen_q"] = config.max_seqlen_q kwargs["max_seqlen_q"] = config.max_seqlen_q
...@@ -75,26 +94,47 @@ def test_current_device(model, module): ...@@ -75,26 +94,47 @@ def test_current_device(model, module):
model = DotProductAttention( model = DotProductAttention(
config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding" config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding"
) )
num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item() seqlens_q = torch.randint(
1,
config.max_seqlen_q,
[config.batch_size],
dtype=torch.int32,
device=f"cuda:{tensor_device}",
)
cu_seqlens_q = torch.zeros(
config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
)
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
seqlens_kv = torch.randint(
1,
config.max_seqlen_kv,
[config.batch_size],
dtype=torch.int32,
device=f"cuda:{tensor_device}",
)
cu_seqlens_kv = torch.zeros(
config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
num_tokens = cu_seqlens_q[-1]
args = [ args = [
torch.randn( torch.randn(
num_tokens, num_tokens,
config.num_heads, config.num_heads,
config.head_dim_qk, config.head_dim_qk,
dtype=dtype, dtype=dtype,
device=tensor_device, device=f"cuda:{tensor_device}",
requires_grad=True, requires_grad=True,
) )
for _ in range(3) for _ in range(3)
] ]
cu_seqlens_q, cu_seqlens_kv = [
torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
]
kwargs["cu_seqlens_q"] = cu_seqlens_q kwargs["cu_seqlens_q"] = cu_seqlens_q
kwargs["cu_seqlens_kv"] = cu_seqlens_kv kwargs["cu_seqlens_kv"] = cu_seqlens_kv
kwargs["max_seqlen_q"] = config.max_seqlen_q kwargs["max_seqlen_q"] = config.max_seqlen_q
kwargs["max_seqlen_kv"] = config.max_seqlen_kv kwargs["max_seqlen_kv"] = config.max_seqlen_kv
bwd_args = [torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=tensor_device)] bwd_args = [
torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=f"cuda:{tensor_device}")
]
elif module == "Linear": elif module == "Linear":
model = Linear( model = Linear(
config.hidden_size, config.hidden_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment