OpenDAS / TransformerEngine

Commit 9d0f1c9b, authored May 08, 2025 by yuguo

[DCU] add batchgemm test

Parent: e8f92b93
Showing 9 changed files with 664 additions and 445 deletions (+664 / -445):

  qa/L0_pytorch_unittest/test.sh                           +1    -0
  setup.py                                                 +1    -1
  tests/pytorch/test_batched_linear.py                     +295  -0
  transformer_engine/common/gemm/cublaslt_gemm.cu          +3    -0
  transformer_engine/pytorch/cpp_extensions/gemm.py        +42   -48
  transformer_engine/pytorch/csrc/extensions.h             +17   -7
  transformer_engine/pytorch/csrc/extensions/gemm.cpp      +87   -106
  transformer_engine/pytorch/csrc/extensions/pybind.cpp    +1    -1
  transformer_engine/pytorch/module/batched_linear.py      +217  -282
qa/L0_pytorch_unittest/test.sh

@@ -26,6 +26,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
+PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_batched_linear.py || test_fail "test_batched_linear.py"
 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
setup.py

@@ -4,7 +4,7 @@
 """Installation script."""
 # NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc pip3 install . -v
-# VTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py bdist_wheel
+# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py bdist_wheel
 import os
 import sys
tests/pytorch/test_batched_linear.py (new file, mode 100644)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from collections import OrderedDict
import math
import os
from typing import Dict, List, Tuple, Optional
import pytest
import copy
import random

import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.utils.cpp_extension import IS_HIP_EXTENSION

from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
    fp8_autocast,
    fp8_model_init,
)
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
    attention_mask_func,
    is_bf16_compatible,
)
from transformer_engine.pytorch import (
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    GroupedLinear,
    BatchedLinear,
    MultiheadAttention,
    RMSNorm,
    TransformerLayer,
    LayerNorm,
    Fp8Padding,
    Fp8Unpadding,
)
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm, batchgemm
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.common import recipe
import transformer_engine_torch as tex

# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

sm_80plus = get_device_compute_capability() >= (8, 0)

seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()

if torch_version() >= (2, 7, 0):
    torch._dynamo.config.recompile_limit = 16
else:
    torch._dynamo.config.cache_size_limit = 16


class ModelConfig:
    def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
        self.hidden_size = hidden_size
        self.eps = eps
        self.num_attention_heads = num_attention_heads
        self.embed = embed
        self.num_layers = num_layers
        self.seq_len = seq_len


model_configs = {
    "small": ModelConfig(128, 1e-5, 8, 36, 4, 128),
    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
}

model_configs_inference = {
    # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 256),
}
backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]

param_types = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    param_types.append(torch.bfloat16)

batch_sizes = [1, 2]

all_boolean = [True, False]

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]
all_normalizations = ["LayerNorm", "RMSNorm"]

mask_types = ["causal", "no_mask"]

fp8_recipes = [
    recipe.MXFP8BlockScaling(),
    recipe.DelayedScaling(),
    recipe.Float8CurrentScaling(),
]


def get_causal_attn_mask(sq: int) -> torch.Tensor:
    return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()


def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
    """Estimated numerical error for a datatype

    Based on tolerances for torch.testing.assert_close.
    """
    if dtype == torch.float32:
        return dict(rtol=1.3e-6, atol=1e-5)
    if dtype == torch.float16:
        return dict(rtol=1e-3, atol=1e-5)
    if dtype == torch.bfloat16:
        return dict(rtol=1.6e-2, atol=1e-5)
    raise ValueError(f"Unsupported dtype ({dtype})")


def assert_allclose(
    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
) -> bool:
    """Ensures two lists are equal."""
    assert len(l1) == len(l2), "Unequal number of outputs."
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        tols = dict(atol=atol)
        if rtol is not None:
            tols["rtol"] = rtol
        result = torch.allclose(t1, t2, **tols)
        if not result:
            diff = torch.abs(t1 - t2)
            tol = atol + (rtol * torch.abs(t2))
            exceed_mask = diff > tol
            if exceed_mask.any():
                indices = torch.nonzero(exceed_mask, as_tuple=True)
                max_diff = diff[exceed_mask].max()
                max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
                max_location = [idx[max_idx].item() for idx in indices]
                msg = (
                    f"Outputs not close enough in tensor at idx={i}. "
                    f"Maximum difference at location {max_location} "
                    f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
                    f"(diff {max_diff.item()})."
                )
                raise AssertionError(msg)


def reset_rng_states() -> None:
    """revert back to initial RNG state."""
    torch.set_rng_state(_cpu_rng_state)
    torch.cuda.set_rng_state(_cuda_rng_state)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    yield
    FP8GlobalStateManager.reset()


def _test_batched_linear_accuracy(
    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
):
    reset_rng_states()
    if fp8:
        FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()

    assert config.seq_len % num_gemms == 0
    m_splits = torch.tensor([config.seq_len // num_gemms for i in range(num_gemms)])
    assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if isinstance(block, BatchedLinear):
            m_splits = m_splits * bs
            out = block(inp_hidden_states, m_splits.tolist())
        else:
            out = torch.cat(
                [
                    block[i](inp)
                    for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist()))
                ]
            )
    loss = out.sum()
    loss.backward()
    torch.cuda.synchronize()

    outputs = [out, inp_hidden_states.grad]
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [4, 8])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_batched_linear_accuracy(
    dtype,
    num_gemms,
    bs,
    model,
    fp8,
    recipe,
    fp8_model_params,
    fuse_wgrad_accumulation,
    parallel_mode=None,
):
    batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8 and recipe.mxfp8():
        # TODO(ksivamani): debug mismatches
        pytest.skip("MXFP8 unsupported for batched linear.")
    if fp8 and recipe.float8_current_scaling():
        pytest.skip("Float8 Current Scaling unsupported for batched linear.")

    config = model_configs[model]
    if config.seq_len % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        batched_linear = BatchedLinear(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            device="cuda",
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
        ).eval()
        sequential_linear = torch.nn.ModuleList(
            [
                Linear(
                    config.hidden_size,
                    4 * config.hidden_size,
                    bias=False,
                    params_dtype=dtype,
                    parallel_mode=parallel_mode,
                    device="cuda",
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                ).eval()
                for _ in range(num_gemms)
            ]
        )

    # Share params
    with torch.no_grad():
        for i in range(num_gemms // batch_num):
            weight = getattr(batched_linear, f"weight{i}").clone()
            # bias = getattr(batched_linear, f"bias{i}").clone()
            if fuse_wgrad_accumulation:
                weight_i = getattr(batched_linear, f"weight{i}")
                weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
            for j in range(batch_num):
                sequential_linear[i * batch_num + j].weight = Parameter(
                    weight[
                        weight.shape[0] // batch_num * j : weight.shape[0] // batch_num * (j + 1)
                    ].clone()
                )
                # sequential_linear[i * batch_num + j].bias = Parameter(bias[bias.shape[0] // batch_num * j : bias.shape[0] // batch_num * (j + 1)].clone())
                if fuse_wgrad_accumulation:
                    sequential_linear[i * batch_num + j].weight.main_grad = weight_i.main_grad[
                        weight_i.main_grad.shape[0] // batch_num * j : weight_i.main_grad.shape[0]
                        // batch_num
                        * (j + 1)
                    ].clone()

    outputs_ref = _test_batched_linear_accuracy(
        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
    )
    outputs = _test_batched_linear_accuracy(
        batched_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
    )

    # Should be a bit-wise match
    for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
        torch.testing.assert_close(o, o_ref, rtol=6e-3, atol=6e-3)


if __name__ == "__main__":
    test_batched_linear_accuracy(
        torch.float32, 2, 1, "126m", False, recipe.Float8CurrentScaling(), True, True
    )
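For orientation, here is a minimal sketch of the calling convention the test exercises for BatchedLinear. The sizes are illustrative assumptions, not values from the commit; the scaling of m_splits by the batch size mirrors _test_batched_linear_accuracy above.

import torch
from transformer_engine.pytorch import BatchedLinear

# Illustrative sizes only (assumptions for this sketch).
seq_len, bs, hidden = 2048, 2, 768
num_gemms = 4

layer = BatchedLinear(num_gemms, hidden, 4 * hidden, bias=False, device="cuda")
inp = torch.randn(seq_len, bs, hidden, device="cuda", requires_grad=True)

# As in the test: equal splits of seq_len, scaled by the batch size,
# so the splits cover all seq_len * bs flattened rows.
m_splits = [(seq_len // num_gemms) * bs for _ in range(num_gemms)]
out = layer(inp, m_splits)
out.sum().backward()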
transformer_engine/common/gemm/cublaslt_gemm.cu

@@ -778,6 +778,9 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
   const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
   Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
   Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
+  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr)) {
+    NVTE_ERROR("MOE batchgemm does not support bias or gelu.");
+  }
   int m, n, k;
   if (!transa && transb) {
transformer_engine/pytorch/cpp_extensions/gemm.py

@@ -18,7 +18,7 @@ from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 __all__ = [
     "general_gemm",
     "general_grouped_gemm",
-    "general_batched_gemm",
+    "batchgemm",
 ]

@@ -226,84 +226,78 @@ def general_grouped_gemm(
     return out, bias, gelu_input


-def general_batched_gemm(
+def batchgemm(
     A: List[torch.Tensor],
     B: List[torch.Tensor],
     out: List[torch.Tensor],
-    out_dtype: torch.dtype,
+    dtype: torch.dtype,
     workspaces: List[torch.Tensor],
-    layout: str = "TN",
-    m_splits: Optional[List[int]] = None,
     gelu: bool = False,
-    grad=False,
+    gelu_input: Optional[List[torch.Tensor]] = None,
+    grad: bool = False,
     accumulate: bool = False,
+    layout: str = "TN",
     bias: Optional[List[torch.Tensor]] = None,
     use_bias: bool = False,
     use_split_accumulator: bool = False,
-    D_dtype: Optional[tex.DType] = None,
-    single_output=False,
-) -> Tuple[List[torch.Tensor], ...]:
-    """
-    TN layout Grouped GEMM with fp8 inputs.
-    """
-    num_gemms = len(A)
-    if isinstance(A[0], Float8TensorBase):
-        for a, b in zip(A, B):
-            assert_dim_for_fp8_exec(a._data)
-            assert_dim_for_fp8_exec(b._data)
-    empty_tensor = _empty_tensor()
-    empty_tensors = [empty_tensor] * num_gemms
+) -> Tuple[Union[torch.Tensor, None], ...]:
+    """Non FP8 batch GEMM."""
+    assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
+    transa = layout[0] == "T"
+    transb = layout[1] == "T"
+    num_gemms = len(A)
+    empty_tensor = torch.Tensor()
+    empty_tensors = [torch.Tensor()] * num_gemms
+    # assert [a.is_contiguous() for a in A]
+    # assert [b.is_contiguous() for b in B]
     # Use bfloat16 as default bias_dtype
+    if gelu and not grad:
+        gelu_input = [
+            torch.empty_like(o, dtype=dtype, memory_format=torch.contiguous_format) for o in out
+        ]
+    elif not gelu:
+        gelu_input = empty_tensors
-    out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
-    sm_count = get_sm_count()
     if grad and use_bias:
         grad_bias = [
             torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
         ]
     else:
         grad_bias = empty_tensors
     bias = bias if use_bias else empty_tensors
+    assert (
+        A[0].dtype == dtype and B[0].dtype == dtype
+    ), f"Expected dtype={dtype}, but found A.dtype={A[0].dtype} and B.dtype={B[0].dtype}"
+    input_dtype = TE_DType[dtype]
+    output_dtype = TE_DType[out[0].dtype]
     if use_bias:
         bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype]
     else:
-        bias_dtype = TE_DType[torch.bfloat16]
-    if gelu:
-        gelu_input = [
-            torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
-            for o in out
-        ]
-    # this should differ with respect to single output
-    bias = tex.te_general_batched_gemm(
+        bias_dtype = output_dtype
+    tex.te_batchgemm_ts(
         A,
+        empty_tensor,
+        0,  # A_offset
+        input_dtype,
         transa,
         B,
+        empty_tensor,
+        0,  # B_offset
+        input_dtype,
         transb,
         out,
-        out_dtype,
-        m_splits,
-        bias,
+        0,  # out_offset
+        empty_tensor,  # out_scale
+        output_dtype,
+        empty_tensor,  # out_amax
+        grad_bias if grad else bias,
         bias_dtype,
-        single_output,
         gelu_input,  # this is pre_gelu_out
         grad,  # grad
         workspaces,
         workspaces[0].shape[0],
         accumulate,
-        use_split_accumulator,
-        sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))),
+        False,  # use_split_accumulator
     )
-    return out, bias, gelu_input
+    return out, grad_bias, gelu_input
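To make the new wrapper's calling convention concrete, a minimal sketch of a direct batchgemm call follows. The per-GEMM shape convention is an assumption inferred from the default "TN" layout (A as [n, k] weights, B as [m, k] inputs, out as [m, n]); the sizes are arbitrary, and get_multi_stream_cublas_workspace is the workspace helper the new test imports.

import torch
from transformer_engine.pytorch.cpp_extensions import batchgemm
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace

num_gemms, m, k, n = 4, 128, 256, 512  # illustrative sizes (assumptions)
dtype = torch.float16

# One list entry per GEMM; "TN" layout assumed: A[i] is [n, k], B[i] is [m, k].
A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(num_gemms)]
B = [torch.randn(m, k, dtype=dtype, device="cuda") for _ in range(num_gemms)]
out = [torch.empty(m, n, dtype=dtype, device="cuda") for _ in range(num_gemms)]

# bias/gelu stay disabled, matching the guard added in cublaslt_gemm.cu.
out, grad_bias, gelu_input = batchgemm(
    A, B, out, dtype, get_multi_stream_cublas_workspace(), layout="TN"
)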
transformer_engine/pytorch/csrc/extensions.h

@@ -102,13 +102,23 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
     bool use_split_accumulator, int math_sm_count);

 #ifdef __HIP_PLATFORM_AMD__
-std::optional<std::vector<at::Tensor>> te_general_batched_gemm(
-    std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
-    std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
-    std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
-    transformer_engine::DType bias_type, bool single_output,
-    std::vector<at::Tensor> pre_gelu_out, bool grad, std::vector<at::Tensor> workspace,
-    size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count);
+void te_batchgemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
+                  transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
+                  at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
+                  bool transb, std::vector<at::Tensor> D, int D_offset, at::Tensor D_scale,
+                  transformer_engine::DType D_type, at::Tensor D_amax,
+                  std::vector<at::Tensor> bias, transformer_engine::DType bias_type,
+                  std::vector<at::Tensor> pre_gelu_out, bool grad,
+                  std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
+                  bool use_split_accumulator, int math_sm_count);
+
+std::vector<at::Tensor> te_batchgemm_ts(
+    std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int64_t A_offset, int64_t A_type,
+    int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse, int64_t B_offset,
+    int64_t B_type, int64_t transb, std::vector<at::Tensor> D, int64_t D_offset,
+    at::Tensor D_scale, int64_t D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
+    int64_t bias_type, std::vector<at::Tensor> pre_gelu_out, int64_t grad,
+    std::vector<at::Tensor> workspace, int64_t workspaceSize, int64_t accumulate,
+    int64_t use_split_accumulator);
 #endif
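The all-int64_t/tensor argument types of te_batchgemm_ts follow the pattern TE uses for TorchScript-friendly bindings; reverse_map_dtype (added in gemm.cpp below) turns the integer back into a transformer_engine::DType. A hypothetical illustration of that round trip from the Python side, assuming the usual DType enum members exported by transformer_engine_torch:

import transformer_engine_torch as tex

# Python passes the DType enum; the TorchScript-friendly binding receives int64_t.
te_dtype = tex.DType.kFloat16
as_int = int(te_dtype)  # what te_batchgemm_ts receives
# On the C++ side, reverse_map_dtype(as_int) casts the int back to
# DType::kFloat16 and raises NVTE_ERROR("Type not supported.") otherwise.
print(te_dtype, as_int)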
transformer_engine/pytorch/csrc/extensions/gemm.cpp

@@ -424,123 +424,104 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
 }

 #ifdef USE_ROCM
-std::optional<std::vector<at::Tensor>> te_general_batched_gemm(
-    std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
-    std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
-    std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
-    transformer_engine::DType bias_type, bool single_output,
-    std::vector<at::Tensor> pre_gelu_out, bool grad, std::vector<at::Tensor> workspace,
-    size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
-  std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector,
-      te_pre_gelu_out_vector, te_workspace_vector;
-  std::vector<TensorWrapper> wrappers;
-  std::vector<at::Tensor> D_vectors;
-  auto none = py::none();
-  std::vector<size_t> single_output_begins;
-  std::vector<size_t> single_output_ends;
-  int slicing_dim;
-  if (single_output && D == std::nullopt) {
-    NVTE_ERROR("not implemented, D should be allocated for single output case.");
-  }
-  void *output_data_ptr;
-  if (single_output) {
-    output_data_ptr = (*D)[0].data_ptr();
-  }
-  for (size_t i = 0; i < A.size(); i++) {
-    auto te_A = makeTransformerEngineTensor(A[i], none);
-    auto te_B = makeTransformerEngineTensor(B[i], none);
-    // if there is single output
-    at::Tensor out_tensor;
-    auto size_t_shape =
-        pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb);
-    bool D_numel_is_zero = false;
-    std::vector<int64_t> D_shape;
-    for (size_t t : size_t_shape) {
-      D_shape.push_back(t);
-      if (t == 0) {
-        D_numel_is_zero = true;
-      }
-    }
-    auto dtype = GetATenDType(D_type);
-    auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
-    if (single_output) {
-      if (output_data_ptr == nullptr) {
-        out_tensor = at::empty(D_shape, opts);
-      } else {
-        // We need to check !D_numel_is_zero because if the final input portion has zero elements,
-        // output_data_ptr would point beyond the allocated memory of D. This would cause
-        // at::from_blob to fail as it would reference memory not allocated by CUDA.
-        if (!D_numel_is_zero) {
-          out_tensor = at::from_blob(output_data_ptr, D_shape, opts);
-        }
-      }
-      char *char_ptr = reinterpret_cast<char *>(output_data_ptr);
-      char_ptr += D_shape[0] * D_shape[1] * (*D)[0].element_size();
-      output_data_ptr = reinterpret_cast<void *>(char_ptr);
-      D_vectors.emplace_back(out_tensor);
-    } else {
-      if (D == std::nullopt) {
-        auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
-        out_tensor = at::empty(D_shape, opts);
-        D_vectors.emplace_back(out_tensor);
-      } else {
-        out_tensor = (*D)[i];
-      }
-    }
-    if (te_A.numel() == 0 || te_B.numel() == 0) {
-      if (out_tensor.numel() != 0 && !accumulate) out_tensor.zero_();
-      if (bias[i].numel() != 0 && grad) {
-        bias[i].zero_();
-      }
-      if (pre_gelu_out[i].numel() != 0) pre_gelu_out[i].zero_();
-      continue;
-    }
-    auto te_D = makeTransformerEngineTensor(out_tensor);
-    auto te_bias = makeTransformerEngineTensor(bias[i]);
-    auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]);
-    const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr
-                                ? std::vector<size_t>{static_cast<size_t>(te_pre_gelu_out.size(0))}
-                                : std::vector<size_t>{static_cast<size_t>(te_pre_gelu_out.size(0)),
-                                                      static_cast<size_t>(te_pre_gelu_out.size(1))};
-    DType gelu_type = bias_type;
-    te_pre_gelu_out =
-        makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type);
-    te_A_vector.emplace_back(te_A.data());
-    te_B_vector.emplace_back(te_B.data());
-    te_D_vector.emplace_back(te_D.data());
-    te_bias_vector.emplace_back(te_bias.data());
-    te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data());
-    wrappers.emplace_back(std::move(te_A));
-    wrappers.emplace_back(std::move(te_B));
-    wrappers.emplace_back(std::move(te_D));
-    wrappers.emplace_back(std::move(te_bias));
-    wrappers.emplace_back(std::move(te_pre_gelu_out));
-  }
-  for (size_t i = 0; i < workspace.size(); i++) {
-    auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte);
-    te_workspace_vector.emplace_back(wsp.data());
-    wrappers.emplace_back(std::move(wsp));
-  }
-  // For now, we only have multi-stream cublas backend.
-  nvte_multi_stream_cublas_batchgemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
-                                     te_bias_vector.data(), te_pre_gelu_out_vector.data(),
-                                     te_A_vector.size(), transa, transb, grad,
-                                     te_workspace_vector.data(), accumulate, use_split_accumulator,
-                                     math_sm_count, at::cuda::getCurrentCUDAStream());
-  return bias;
-}
+void te_batchgemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
+                  transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
+                  at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
+                  bool transb, std::vector<at::Tensor> D, int D_offset, at::Tensor D_scale,
+                  transformer_engine::DType D_type, at::Tensor D_amax,
+                  std::vector<at::Tensor> bias, transformer_engine::DType bias_type,
+                  std::vector<at::Tensor> pre_gelu_out, bool grad,
+                  std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
+                  bool use_split_accumulator, int math_sm_count) {
+  using namespace transformer_engine;
+  using namespace transformer_engine::pytorch;
+  std::vector<NVTETensor> te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace;
+  std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
+  auto make_tensor = [&tensor_wrappers](void *dptr, const std::vector<size_t> &shape,
+                                        transformer_engine::DType dtype, void *amax_dptr,
+                                        void *scale_dptr, void *scale_inv_dptr) -> NVTETensor {
+    tensor_wrappers.emplace_back(
+        makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr));
+    return tensor_wrappers.back().data();
+  };
+  for (size_t i = 0; i < A.size(); i++) {
+    if (A[i].data_ptr() == nullptr || B[i].data_ptr() == nullptr) {
+      if (D[i].data_ptr() != nullptr && !accumulate) D[i].zero_();
+      if (bias[i].data_ptr() != nullptr) bias[i].zero_();
+      if (pre_gelu_out[i].data_ptr() != nullptr) pre_gelu_out[i].zero_();
+      continue;
+    }
+    te_A.emplace_back(make_tensor(
+        A[i].data_ptr(), {static_cast<size_t>(A[i].size(0)), static_cast<size_t>(A[i].size(1))},
+        A_type, nullptr, nullptr, getDataPtr(A_scale_inverse, A_offset + i)));
+    te_B.emplace_back(make_tensor(
+        B[i].data_ptr(), {static_cast<size_t>(B[i].size(0)), static_cast<size_t>(B[i].size(1))},
+        B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i)));
+    te_D.emplace_back(make_tensor(
+        D[i].data_ptr(), {static_cast<size_t>(D[i].size(0)), static_cast<size_t>(D[i].size(1))},
+        D_type, getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr));
+    te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast<size_t>(bias[i].size(0))},
+                                     bias_type, nullptr, nullptr, nullptr));
+    const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr
+                                ? std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0))}
+                                : std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0)),
+                                                      static_cast<size_t>(pre_gelu_out[i].size(1))};
+    te_pre_gelu_out.emplace_back(make_tensor(
+        pre_gelu_out[i].data_ptr(), gelu_shape,
+        GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr));
+  }
+  for (size_t i = 0; i < workspace.size(); i++) {
+    te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte,
+                                          nullptr, nullptr, nullptr));
+  }
+  // For now, we only have multi-stream cublas backend.
+  nvte_multi_stream_cublas_batchgemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
+                                     te_pre_gelu_out.data(), te_A.size(), transa, transb, grad,
+                                     te_workspace.data(), accumulate, use_split_accumulator,
+                                     math_sm_count, at::cuda::getCurrentCUDAStream());
+}
+
+transformer_engine::DType reverse_map_dtype(int64_t dtype) {
+  if (dtype >= 0 && dtype < static_cast<int64_t>(transformer_engine::DType::kNumTypes)) {
+    return static_cast<transformer_engine::DType>(dtype);
+  } else {
+    NVTE_ERROR("Type not supported.");
+  }
+}
+
+std::vector<at::Tensor> te_batchgemm_ts(
+    std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int64_t A_offset, int64_t A_type,
+    int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse, int64_t B_offset,
+    int64_t B_type, int64_t transb, std::vector<at::Tensor> D, int64_t D_offset,
+    at::Tensor D_scale, int64_t D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
+    int64_t bias_type, std::vector<at::Tensor> pre_gelu_out, int64_t grad,
+    std::vector<at::Tensor> workspace, int64_t workspaceSize, int64_t accumulate,
+    int64_t use_split_accumulator) {
+  using namespace transformer_engine;
+  using namespace transformer_engine::pytorch;
+  // cast inputs to types accepted by te_gemm
+  transformer_engine::DType A_type_arg = reverse_map_dtype(A_type);
+  bool transa_arg = static_cast<bool>(transa);
+  transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
+  bool transb_arg = static_cast<bool>(transb);
+  transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
+  transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type);
+  bool grad_arg = static_cast<bool>(grad);
+  size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
+  bool accumulate_arg = static_cast<bool>(accumulate);
+  bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);
+  // Set an external SM Margin to all the GEMMs.
+  // This comes in handy when DP is overlapped with GEMMs
+  const int sm_count = transformer_engine::cuda::sm_count();
+  int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
+  te_batchgemm(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, B_scale_inverse, B_offset,
+               B_type_arg, transb_arg, D, D_offset, D_scale, D_type_arg, D_amax, bias,
+               bias_type_arg, pre_gelu_out, grad_arg, workspace, workspaceSize_arg,
+               accumulate_arg, use_split_accumulator_arg, num_math_sms);
+  return D;
+}
 #endif
transformer_engine/pytorch/csrc/extensions/pybind.cpp

@@ -175,7 +175,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM");
 #ifdef USE_ROCM
-  m.def("te_general_batched_gemm", &te_general_batched_gemm, "Batched GEMM");  /// rocblas
+  m.def("te_batchgemm_ts", &te_batchgemm_ts, "Batched GEMM");  /// rocblas
 #endif
   m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"),
         py::arg("dtype"), py::kw_only(), py::arg("out"),
         py::call_guard<py::gil_scoped_release>());
transformer_engine/pytorch/module/batched_linear.py

(This diff is collapsed in the original view and is not reproduced here; +217, -282.)