Unverified Commit 4296b7d0 authored by Charlene Yang, committed by GitHub

Fix the device for cuDNN/cuBLAS handles (#1974)



* fix current device for cuDNN/cuBLAS handles
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add unit test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* use weight device and improve tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent fdb87afc
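
The change wraps each module's forward pass in a torch.cuda.device(...) context keyed on the input or weight device, so cuDNN/cuBLAS handles are created and used on the GPU that actually holds the tensors, even when the caller's current CUDA device is a different GPU. A minimal sketch of the pattern, not TE's exact internals (forward_on_param_device is a hypothetical helper; the commented usage assumes two visible GPUs):

import torch

def forward_on_param_device(module, inp):
    # Sketch of the guard used in this commit: enter the CUDA device of the
    # module's first parameter so that any lazily created cuDNN/cuBLAS handles
    # are bound to that device, regardless of the caller's current device.
    param_device = next(module.parameters()).device
    with torch.cuda.device(param_device):
        return module(inp)

# Hypothetical usage: module and input on cuda:1 while the current device is cuda:0.
# lin = torch.nn.Linear(32, 32).to("cuda:1")
# x = torch.randn(4, 32, device="cuda:1")
# out = forward_on_param_device(lin, x)
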
@@ -23,6 +23,7 @@ mkdir -p "$XML_LOG_DIR"
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pathlib
import sys

import pytest
import torch

import transformer_engine
from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
from transformer_engine.pytorch import TransformerLayer, Linear

_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import ModelConfig

model_configs = {
    "small": ModelConfig(2, 10, 2, 16),
}


@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention", "Linear"])
def test_current_device(model, module):
    """Test cases where the current device differs from the tensor device."""
    num_devices = torch.cuda.device_count()
    assert num_devices > 1, "This test requires more than one GPU!"
    # Place the module and its tensors on the last GPU while the current device stays cuda:0.
    tensor_device = num_devices - 1
    dtype = torch.bfloat16
    config = model_configs[model]
    args = []
    kwargs = {}
    bwd_args = []
    if module == "TransformerLayer":
        model = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_heads,
            params_dtype=dtype,
            attn_input_format="thd",
            self_attn_mask_type="padding",
            device=f"cuda:{tensor_device}",
        )
        num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
        args = [
            torch.randn(
                (num_tokens, config.hidden_size),
                dtype=dtype,
                device=f"cuda:{tensor_device}",
                requires_grad=True,
            )
        ]
        cu_seqlens_q, cu_seqlens_kv = [
            torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
        ]
        kwargs["cu_seqlens_q"] = cu_seqlens_q
        kwargs["cu_seqlens_kv"] = cu_seqlens_kv
        kwargs["max_seqlen_q"] = config.max_seqlen_q
        kwargs["max_seqlen_kv"] = config.max_seqlen_kv
    elif module == "DotProductAttention":
        model = DotProductAttention(
            config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding"
        )
        num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
        args = [
            torch.randn(
                num_tokens,
                config.num_heads,
                config.head_dim_qk,
                dtype=dtype,
                device=tensor_device,
                requires_grad=True,
            )
            for _ in range(3)
        ]
        cu_seqlens_q, cu_seqlens_kv = [
            torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
        ]
        kwargs["cu_seqlens_q"] = cu_seqlens_q
        kwargs["cu_seqlens_kv"] = cu_seqlens_kv
        kwargs["max_seqlen_q"] = config.max_seqlen_q
        kwargs["max_seqlen_kv"] = config.max_seqlen_kv
        bwd_args = [torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=tensor_device)]
    elif module == "Linear":
        model = Linear(
            config.hidden_size,
            4 * config.hidden_size,
            params_dtype=dtype,
            device=f"cuda:{tensor_device}",
        )
        args = [
            torch.randn(
                (config.max_seqlen_q, config.batch_size, config.hidden_size),
                dtype=dtype,
                device=f"cuda:{tensor_device}",
                requires_grad=True,
            )
        ]

    # Run forward and backward while the current device (cuda:0) differs from the tensor device.
    current_device_before = torch.cuda.current_device()
    out = model(*args, **kwargs)
    if module == "DotProductAttention":
        out.backward(*bwd_args)
    else:
        loss = out.sum()
        loss.backward()
    current_device_after = torch.cuda.current_device()
    tensor_device_out = out.get_device()
    tensor_device_grad = args[0].grad.get_device()

    assert (
        current_device_after == current_device_before
    ), "The current device should not have changed!"
    assert (
        tensor_device_out == tensor_device
    ), "The output tensor device should match the input tensor device!"
    assert (
        tensor_device_grad == tensor_device
    ), "The gradient tensor device should match the input tensor device!"

@@ -630,7 +630,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                If true, there are padding tokens between individual sequences in a packed batch.
        """
-        with self.prepare_forward(
+        with torch.cuda.device(query_layer.device), self.prepare_forward(
            query_layer,
            num_gemms=3,
            allow_non_contiguous=True,

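torch.cuda.device is a plain context manager: it switches the current CUDA device on entry and restores the previous one on exit, which is why callers never observe a device change after forward returns. A small illustration of that behavior (assumes at least two visible GPUs):

import torch

# The guard temporarily switches the current device and restores it on exit.
assert torch.cuda.device_count() >= 2  # assumption for this sketch
torch.cuda.set_device(0)
before = torch.cuda.current_device()      # 0
with torch.cuda.device(1):
    inside = torch.cuda.current_device()  # 1
after = torch.cuda.current_device()       # back to 0
assert (before, inside, after) == (0, 1, 0)
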
@@ -742,7 +742,9 @@ class GroupedLinear(TransformerEngineBaseModule):
        if skip_fp8_weight_update is not None:
            is_first_microbatch = False
-        with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
+        with torch.cuda.device(
+            getattr(self, list(self.named_parameters())[0][0]).device
+        ), self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
            weight_tensors = self._get_weight_tensors()
            bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]

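For the parameterized modules, the guard is keyed on the module's own weights: getattr(self, list(self.named_parameters())[0][0]).device looks up the first registered parameter by name and takes its device, so the context follows wherever the weights live. A small sketch of why this works (TinyModule is a toy stand-in; for parameters registered directly on the module, as in these layers, the expression matches the more common next(self.parameters()).device idiom):

import torch

class TinyModule(torch.nn.Module):
    # Toy stand-in for a TE module whose parameters (e.g. weight0, bias0)
    # are registered directly on the module.
    def __init__(self):
        super().__init__()
        self.weight0 = torch.nn.Parameter(torch.zeros(4, 4))

m = TinyModule()
first_name = list(m.named_parameters())[0][0]  # "weight0"
dev_a = getattr(m, first_name).device          # expression from the diff
dev_b = next(m.parameters()).device            # common equivalent idiom
assert dev_a == dev_b
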
@@ -1484,7 +1484,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
            if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
                fp8_grad = True
-        with self.prepare_forward(
+        with torch.cuda.device(
+            getattr(self, list(self.named_parameters())[0][0]).device
+        ), self.prepare_forward(
            inp, allow_non_contiguous=False  # removed .contiguous from inside the layer
        ) as inp:

@@ -1740,7 +1740,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
            if get_ub("fc2_fprop").is_fp8_ubuf():
                fp8_output = True
-        with self.prepare_forward(inp, num_gemms=2) as inp:
+        with torch.cuda.device(
+            getattr(self, list(self.named_parameters())[0][0]).device
+        ), self.prepare_forward(inp, num_gemms=2) as inp:
            quantizers = (
                self._get_quantizers(fp8_output)

@@ -1353,7 +1353,9 @@ class Linear(TransformerEngineBaseModule):
            if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
                fp8_grad = True
-        with self.prepare_forward(
+        with torch.cuda.device(
+            getattr(self, list(self.named_parameters())[0][0]).device
+        ), self.prepare_forward(
            inp,
            allow_non_contiguous=isinstance(inp, QuantizedTensor),
        ) as inp:
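
From the user's side, the new unit test covers exactly this scenario: a module constructed on the last GPU is called while the current device is cuda:0, and the output and gradients must stay on the module's GPU without the current device changing. A minimal usage sketch mirroring the Linear case in the test (assumes Transformer Engine is installed and at least two GPUs are visible):

import torch
import transformer_engine.pytorch as te

# Module and input live on the last GPU; the caller's current device stays cuda:0.
assert torch.cuda.device_count() > 1
dev_id = torch.cuda.device_count() - 1
linear = te.Linear(32, 128, params_dtype=torch.bfloat16, device=f"cuda:{dev_id}")
x = torch.randn(8, 2, 32, dtype=torch.bfloat16, device=f"cuda:{dev_id}", requires_grad=True)

before = torch.cuda.current_device()
linear(x).sum().backward()
assert torch.cuda.current_device() == before            # current device unchanged
assert x.get_device() == dev_id and x.grad.get_device() == dev_id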