OpenDAS / TransformerEngine · Commits

Commit 9d0f1c9b, authored May 08, 2025 by yuguo
[DCU] add batchgemm test
Parent: e8f92b93

Showing 9 changed files with 664 additions and 445 deletions (+664 / -445)
qa/L0_pytorch_unittest/test.sh                          +1    -0
setup.py                                                +1    -1
tests/pytorch/test_batched_linear.py                    +295  -0
transformer_engine/common/gemm/cublaslt_gemm.cu         +3    -0
transformer_engine/pytorch/cpp_extensions/gemm.py       +42   -48
transformer_engine/pytorch/csrc/extensions.h            +17   -7
transformer_engine/pytorch/csrc/extensions/gemm.cpp     +87   -106
transformer_engine/pytorch/csrc/extensions/pybind.cpp   +1    -1
transformer_engine/pytorch/module/batched_linear.py     +217  -282
qa/L0_pytorch_unittest/test.sh

@@ -26,6 +26,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
+PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_batched_linear.py || test_fail "test_batched_linear.py"
 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
...
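The same run can be reproduced outside the shell script; a minimal sketch (assuming pytest and a ROCm/DCU build of Transformer Engine are installed and the repository root is the working directory) that mirrors the line added above:

import os
import pytest

# Match the environment variables used by the new CI line.
os.environ["PYTORCH_JIT"] = "0"
os.environ["NVTE_TORCH_COMPILE"] = "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
# Run only the test file introduced by this commit.
raise SystemExit(pytest.main(["-v", "-s", "tests/pytorch/test_batched_linear.py"]))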
setup.py

@@ -4,7 +4,7 @@
 """Installation script."""

 # NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc pip3 install . -v
-# VTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py bdist_wheel
+# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py bdist_wheel

 import os
 import sys
...
tests/pytorch/test_batched_linear.py (new file, mode 100644)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from collections import OrderedDict
import math
import os
from typing import Dict, List, Tuple, Optional
import pytest
import copy
import random

import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.utils.cpp_extension import IS_HIP_EXTENSION

from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
    fp8_autocast,
    fp8_model_init,
)
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
    attention_mask_func,
    is_bf16_compatible,
)
from transformer_engine.pytorch import (
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    GroupedLinear,
    BatchedLinear,
    MultiheadAttention,
    RMSNorm,
    TransformerLayer,
    LayerNorm,
    Fp8Padding,
    Fp8Unpadding,
)
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm, batchgemm
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.common import recipe
import transformer_engine_torch as tex

# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

sm_80plus = get_device_compute_capability() >= (8, 0)

seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()

if torch_version() >= (2, 7, 0):
    torch._dynamo.config.recompile_limit = 16
else:
    torch._dynamo.config.cache_size_limit = 16


class ModelConfig:
    def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
        self.hidden_size = hidden_size
        self.eps = eps
        self.num_attention_heads = num_attention_heads
        self.embed = embed
        self.num_layers = num_layers
        self.seq_len = seq_len


model_configs = {
    "small": ModelConfig(128, 1e-5, 8, 36, 4, 128),
    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
}

model_configs_inference = {
    # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 256),
}
backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]

param_types = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    param_types.append(torch.bfloat16)

batch_sizes = [1, 2]

all_boolean = [True, False]

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]
all_normalizations = ["LayerNorm", "RMSNorm"]

mask_types = ["causal", "no_mask"]

fp8_recipes = [
    recipe.MXFP8BlockScaling(),
    recipe.DelayedScaling(),
    recipe.Float8CurrentScaling(),
]


def get_causal_attn_mask(sq: int) -> torch.Tensor:
    return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()


def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
    """Estimated numerical error for a datatype

    Based on tolerances for torch.testing.assert_close.
    """
    if dtype == torch.float32:
        return dict(rtol=1.3e-6, atol=1e-5)
    if dtype == torch.float16:
        return dict(rtol=1e-3, atol=1e-5)
    if dtype == torch.bfloat16:
        return dict(rtol=1.6e-2, atol=1e-5)
    raise ValueError(f"Unsupported dtype ({dtype})")


def assert_allclose(
    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
) -> bool:
    """Ensures two lists are equal."""
    assert len(l1) == len(l2), "Unequal number of outputs."
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        tols = dict(atol=atol)
        if rtol is not None:
            tols["rtol"] = rtol
        result = torch.allclose(t1, t2, **tols)
        if not result:
            diff = torch.abs(t1 - t2)
            tol = atol + (rtol * torch.abs(t2))
            exceed_mask = diff > tol
            if exceed_mask.any():
                indices = torch.nonzero(exceed_mask, as_tuple=True)
                max_diff = diff[exceed_mask].max()
                max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
                max_location = [idx[max_idx].item() for idx in indices]
                msg = (
                    f"Outputs not close enough in tensor at idx={i}. "
                    f"Maximum difference at location {max_location} "
                    f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
                    f"(diff {max_diff.item()})."
                )
                raise AssertionError(msg)


def reset_rng_states() -> None:
    """revert back to initial RNG state."""
    torch.set_rng_state(_cpu_rng_state)
    torch.cuda.set_rng_state(_cuda_rng_state)


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    yield
    FP8GlobalStateManager.reset()


def _test_batched_linear_accuracy(
    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
):
    reset_rng_states()
    if fp8:
        FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
        (config.seq_len, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()

    assert config.seq_len % num_gemms == 0
    m_splits = torch.tensor([config.seq_len // num_gemms for i in range(num_gemms)])
    assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if isinstance(block, BatchedLinear):
            m_splits = m_splits * bs
            out = block(inp_hidden_states, m_splits.tolist())
        else:
            out = torch.cat(
                [
                    block[i](inp)
                    for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist()))
                ]
            )
    loss = out.sum()
    loss.backward()
    torch.cuda.synchronize()

    outputs = [out, inp_hidden_states.grad]
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [4, 8])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_batched_linear_accuracy(
    dtype,
    num_gemms,
    bs,
    model,
    fp8,
    recipe,
    fp8_model_params,
    fuse_wgrad_accumulation,
    parallel_mode=None,
):
    batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8 and recipe.mxfp8():
        # TODO(ksivamani): debug mismatches
        pytest.skip("MXFP8 unsupported for batched linear.")
    if fp8 and recipe.float8_current_scaling():
        pytest.skip("Float8 Current Scaling unsupported for batched linear.")

    config = model_configs[model]
    if config.seq_len % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        batched_linear = BatchedLinear(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            device="cuda",
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
        ).eval()
        sequential_linear = torch.nn.ModuleList(
            [
                Linear(
                    config.hidden_size,
                    4 * config.hidden_size,
                    bias=False,
                    params_dtype=dtype,
                    parallel_mode=parallel_mode,
                    device="cuda",
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                ).eval()
                for _ in range(num_gemms)
            ]
        )

    # Share params
    with torch.no_grad():
        for i in range(num_gemms // batch_num):
            weight = getattr(batched_linear, f"weight{i}").clone()
            # bias = getattr(batched_linear, f"bias{i}").clone()
            if fuse_wgrad_accumulation:
                weight_i = getattr(batched_linear, f"weight{i}")
                weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
            for j in range(batch_num):
                sequential_linear[i * batch_num + j].weight = Parameter(
                    weight[weight.shape[0] // batch_num * j : weight.shape[0] // batch_num * (j + 1)].clone()
                )
                # sequential_linear[i * batch_num + j].bias = Parameter(bias[bias.shape[0] // batch_num * j : bias.shape[0] // batch_num * (j + 1)].clone())
                if fuse_wgrad_accumulation:
                    sequential_linear[i * batch_num + j].weight.main_grad = weight_i.main_grad[
                        weight_i.main_grad.shape[0] // batch_num * j : weight_i.main_grad.shape[0] // batch_num * (j + 1)
                    ].clone()

    outputs_ref = _test_batched_linear_accuracy(
        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
    )
    outputs = _test_batched_linear_accuracy(
        batched_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
    )

    # Should be bit-wise match
    for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
        torch.testing.assert_close(o, o_ref, rtol=6e-3, atol=6e-3)


if __name__ == "__main__":
    test_batched_linear_accuracy(
        torch.float32, 2, 1, "126m", False, recipe.Float8CurrentScaling(), True, True
    )
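Distilled from the test above, a minimal usage sketch of the module under test (assuming a ROCm/DCU build of Transformer Engine that provides BatchedLinear; the sizes and batch size of 1 are illustrative only):

import torch
from transformer_engine.pytorch import BatchedLinear

num_gemms, hidden = 4, 768
layer = BatchedLinear(
    num_gemms,
    hidden,
    4 * hidden,
    bias=False,
    params_dtype=torch.float32,
    device="cuda",
)

# One contiguous activation tensor; m_splits gives the per-GEMM row counts,
# just as the test derives them from config.seq_len // num_gemms.
inp = torch.randn(2048, 1, hidden, device="cuda", requires_grad=True)
m_splits = [2048 // num_gemms] * num_gemms

out = layer(inp, m_splits)  # all splits are computed through one batched GEMM path
out.sum().backward()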
transformer_engine/common/gemm/cublaslt_gemm.cu

@@ -778,6 +778,9 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
   const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
   Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
   Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
+  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr)) {
+    NVTE_ERROR("MOE batchgemm does not support bias or gelu.");
+  }

   int m, n, k;
   if (!transa && transb) {
...
transformer_engine/pytorch/cpp_extensions/gemm.py

@@ -18,7 +18,7 @@ from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 __all__ = [
     "general_gemm",
     "general_grouped_gemm",
-    "general_batched_gemm",
+    "batchgemm",
 ]
...
@@ -226,84 +226,78 @@ def general_grouped_gemm(
     return out, bias, gelu_input


-def general_batched_gemm(
-    A: List[torch.Tensor],
-    B: List[torch.Tensor],
-    out: List[torch.Tensor],
-    out_dtype: torch.dtype,
-    workspaces: List[torch.Tensor],
-    m_splits: Optional[List[int]] = None,
-    gelu: bool = False,
-    grad=False,
-    accumulate: bool = False,
-    layout: str = "TN",
-    bias: Optional[List[torch.Tensor]] = None,
-    use_bias: bool = False,
-    use_split_accumulator: bool = False,
-    D_dtype: Optional[tex.DType] = None,
-    single_output=False,
-) -> Tuple[List[torch.Tensor], ...]:
-    """
-    TN layout Grouped GEMM with fp8 inputs.
-    """
-    num_gemms = len(A)
-    transa = layout[0] == "T"
-    transb = layout[1] == "T"
-
-    # assert [a.is_contiguous() for a in A]
-    # assert [b.is_contiguous() for b in B]
-
-    if isinstance(A[0], Float8TensorBase):
-        for a, b in zip(A, B):
-            assert_dim_for_fp8_exec(a._data)
-            assert_dim_for_fp8_exec(b._data)
-
-    empty_tensor = _empty_tensor()
-    empty_tensors = [empty_tensor] * num_gemms
-
-    # Use bfloat16 as default bias_dtype
-    gelu_input = empty_tensors
-    out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
-    sm_count = get_sm_count()
-    if grad and use_bias:
-        grad_bias = [
-            torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
-        ]
-    else:
-        grad_bias = empty_tensors
-    bias = bias if use_bias else empty_tensors
-    if use_bias:
-        bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype]
-    else:
-        bias_dtype = TE_DType[torch.bfloat16]
-    if gelu:
-        gelu_input = [
-            torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
-            for o in out
-        ]
-
-    bias = tex.te_general_batched_gemm(
-        A,
-        transa,
-        B,
-        transb,
-        out,
-        out_dtype,
-        m_splits,
-        empty_tensor,  # out_scale
-        empty_tensor,  # out_amax
-        grad_bias if grad else bias,
-        bias_dtype,
-        single_output,
-        gelu_input,  # gelu_input
-        grad,  # grad
-        workspaces,
-        workspaces[0].shape[0],
-        accumulate,
-        use_split_accumulator,
-        sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))),
-    )
-
-    return out, bias, gelu_input
+def batchgemm(
+    A: List[torch.Tensor],
+    B: List[torch.Tensor],
+    out: List[torch.Tensor],
+    dtype: torch.dtype,
+    workspaces: List[torch.Tensor],
+    layout: str = "TN",
+    gelu: bool = False,
+    gelu_input: Optional[List[torch.Tensor]] = None,
+    grad: bool = False,
+    accumulate: bool = False,
+    bias: Optional[List[torch.Tensor]] = None,
+    use_bias: bool = False,
+) -> Tuple[Union[torch.Tensor, None], ...]:
+    """Non FP8 batch GEMM."""
+    num_gemms = len(A)
+    assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
+    transa = layout[0] == "T"
+    transb = layout[1] == "T"
+
+    empty_tensor = torch.Tensor()
+    empty_tensors = [torch.Tensor()] * num_gemms
+
+    if gelu and not grad:
+        gelu_input = [
+            torch.empty_like(o, dtype=dtype, memory_format=torch.contiguous_format) for o in out
+        ]
+    elif not gelu:
+        gelu_input = empty_tensors
+
+    if grad and use_bias:
+        grad_bias = [
+            torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
+        ]
+    else:
+        grad_bias = empty_tensors
+    bias = bias if use_bias else empty_tensors
+
+    assert (
+        A[0].dtype == dtype and B[0].dtype == dtype
+    ), f"Expected dtype={dtype}, but found A.dtype={A[0].dtype} and B.dtype={B[0].dtype}"
+    input_dtype = TE_DType[dtype]
+    output_dtype = TE_DType[out[0].dtype]
+    if use_bias:
+        bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype]
+    else:
+        bias_dtype = output_dtype
+
+    tex.te_batchgemm_ts(
+        A,
+        empty_tensor,
+        0,  # A_offset
+        input_dtype,
+        transa,
+        B,
+        empty_tensor,
+        0,  # B_offset
+        input_dtype,
+        transb,
+        out,
+        0,  # out_offset
+        empty_tensor,  # out_scale
+        output_dtype,
+        empty_tensor,  # out_amax
+        grad_bias if grad else bias,
+        bias_dtype,
+        gelu_input,  # this is pre_gelu_out
+        grad,  # grad
+        workspaces,
+        workspaces[0].shape[0],
+        accumulate,
+        False,  # use_split_accumulator
+    )
+
+    return out, grad_bias, gelu_input
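A rough call sketch for the renamed wrapper, based on the signature reconstructed above rather than a published API (the tensor sizes, the workspace helper, and the default "TN" layout giving out_i = inputs_i @ weights_i.T are assumptions carried over from that reconstruction):

import torch
from transformer_engine.pytorch.cpp_extensions import batchgemm
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace

dtype = torch.float16
num_gemms = 4
weights = [torch.randn(3072, 768, dtype=dtype, device="cuda") for _ in range(num_gemms)]
inputs = [torch.randn(512, 768, dtype=dtype, device="cuda") for _ in range(num_gemms)]
outs = [torch.empty(512, 3072, dtype=dtype, device="cuda") for _ in range(num_gemms)]

# Default layout "TN": each out_i = inputs_i @ weights_i.T, dispatched through
# the multi-stream batched BLAS backend.
out, grad_bias, gelu_input = batchgemm(
    weights, inputs, outs, dtype, get_multi_stream_cublas_workspace()
)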
transformer_engine/pytorch/csrc/extensions.h

@@ -102,13 +102,23 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
                                   bool use_split_accumulator, int math_sm_count);

 #ifdef __HIP_PLATFORM_AMD__
-std::optional<std::vector<at::Tensor>> te_general_batched_gemm(
-    std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
-    std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
-    std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
-    transformer_engine::DType bias_type, bool single_output,
-    std::vector<at::Tensor> pre_gelu_out, bool grad, std::vector<at::Tensor> workspace,
-    size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count);
+void te_batchgemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
+                  transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
+                  at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
+                  bool transb, std::vector<at::Tensor> D, int D_offset, at::Tensor D_scale,
+                  transformer_engine::DType D_type, at::Tensor D_amax,
+                  std::vector<at::Tensor> bias, transformer_engine::DType bias_type,
+                  std::vector<at::Tensor> pre_gelu_out, bool grad,
+                  std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
+                  bool use_split_accumulator, int math_sm_count);
+
+std::vector<at::Tensor> te_batchgemm_ts(
+    std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int64_t A_offset, int64_t A_type,
+    int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse, int64_t B_offset,
+    int64_t B_type, int64_t transb, std::vector<at::Tensor> D, int64_t D_offset,
+    at::Tensor D_scale, int64_t D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
+    int64_t bias_type, std::vector<at::Tensor> pre_gelu_out, int64_t grad,
+    std::vector<at::Tensor> workspace, int64_t workspaceSize, int64_t accumulate,
+    int64_t use_split_accumulator);
 #endif

 /***************************************************************************************************
...
transformer_engine/pytorch/csrc/extensions/gemm.cpp

@@ -424,123 +424,104 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
 }

 #ifdef USE_ROCM
-std::optional<std::vector<at::Tensor>> te_general_batched_gemm(
-    std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
-    std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
-    std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
-    transformer_engine::DType bias_type, bool single_output,
-    std::vector<at::Tensor> pre_gelu_out, bool grad, std::vector<at::Tensor> workspace,
-    size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
-  std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector,
-      te_pre_gelu_out_vector, te_workspace_vector;
-  std::vector<TensorWrapper> wrappers;
-  std::vector<at::Tensor> D_vectors;
-  auto none = py::none();
-  std::vector<size_t> single_output_begins;
-  std::vector<size_t> single_output_ends;
-  int slicing_dim;
-  if (single_output && D == std::nullopt) {
-    NVTE_ERROR("not implemented, D should be allocated for single output case.");
-  }
-  void *output_data_ptr;
-  if (single_output) {
-    output_data_ptr = (*D)[0].data_ptr();
-  }
-  for (size_t i = 0; i < A.size(); i++) {
-    auto te_A = makeTransformerEngineTensor(A[i], none);
-    auto te_B = makeTransformerEngineTensor(B[i], none);
-    at::Tensor out_tensor;
-    auto size_t_shape =
-        pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb);
-    bool D_numel_is_zero = false;
-    std::vector<int64_t> D_shape;
-    for (size_t t : size_t_shape) {
-      D_shape.push_back(t);
-      if (t == 0) {
-        D_numel_is_zero = true;
-      }
-    }
-    auto dtype = GetATenDType(D_type);
-    auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
-    if (single_output) {
-      if (output_data_ptr == nullptr) {
-        out_tensor = at::empty(D_shape, opts);
-      } else {
-        // We need to check !D_numel_is_zero because if the final input portion has zero elements,
-        // output_data_ptr would point beyond the allocated memory of D. This would cause
-        // at::from_blob to fail as it would reference memory not allocated by CUDA.
-        if (!D_numel_is_zero) {
-          out_tensor = at::from_blob(output_data_ptr, D_shape, opts);
-        }
-      }
-      char *char_ptr = reinterpret_cast<char *>(output_data_ptr);
-      char_ptr += D_shape[0] * D_shape[1] * (*D)[0].element_size();
-      output_data_ptr = reinterpret_cast<void *>(char_ptr);
-      D_vectors.emplace_back(out_tensor);
-    } else {
-      if (D == std::nullopt) {
-        auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
-        out_tensor = at::empty(D_shape, opts);
-        D_vectors.emplace_back(out_tensor);
-      } else {
-        out_tensor = (*D)[i];
-      }
-    }
-    if (te_A.numel() == 0 || te_B.numel() == 0) {
-      if (out_tensor.numel() != 0 && !accumulate) out_tensor.zero_();
-      if (bias[i].numel() != 0 && grad) {
-        bias[i].zero_();
-      }
-      if (pre_gelu_out[i].numel() != 0) pre_gelu_out[i].zero_();
-      continue;
-    }
-    auto te_D = makeTransformerEngineTensor(out_tensor);
-    auto te_bias = makeTransformerEngineTensor(bias[i]);
-    auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]);
-    const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr
-                                ? std::vector<size_t>{static_cast<size_t>(te_pre_gelu_out.size(0))}
-                                : std::vector<size_t>{static_cast<size_t>(te_pre_gelu_out.size(0)),
-                                                      static_cast<size_t>(te_pre_gelu_out.size(1))};
-    DType gelu_type = bias_type;
-    te_pre_gelu_out =
-        makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type);
-    te_A_vector.emplace_back(te_A.data());
-    te_B_vector.emplace_back(te_B.data());
-    te_D_vector.emplace_back(te_D.data());
-    te_bias_vector.emplace_back(te_bias.data());
-    te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data());
-    wrappers.emplace_back(std::move(te_A));
-    wrappers.emplace_back(std::move(te_B));
-    wrappers.emplace_back(std::move(te_D));
-    wrappers.emplace_back(std::move(te_bias));
-    wrappers.emplace_back(std::move(te_pre_gelu_out));
-  }
-  for (size_t i = 0; i < workspace.size(); i++) {
-    auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte);
-    te_workspace_vector.emplace_back(wsp.data());
-    wrappers.emplace_back(std::move(wsp));
-  }
-  // For now, we only have multi-stream cublas backend.
-  nvte_multi_stream_cublas_batchgemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
-                                     te_bias_vector.data(), te_pre_gelu_out_vector.data(),
-                                     te_A_vector.size(), transa, transb, grad,
-                                     te_workspace_vector.data(), accumulate, use_split_accumulator,
-                                     math_sm_count, at::cuda::getCurrentCUDAStream());
-  return bias;
-}
+void te_batchgemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
+                  transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
+                  at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
+                  bool transb, std::vector<at::Tensor> D, int D_offset, at::Tensor D_scale,
+                  transformer_engine::DType D_type, at::Tensor D_amax,
+                  std::vector<at::Tensor> bias, transformer_engine::DType bias_type,
+                  std::vector<at::Tensor> pre_gelu_out, bool grad,
+                  std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
+                  bool use_split_accumulator, int math_sm_count) {
+  using namespace transformer_engine;
+  using namespace transformer_engine::pytorch;
+  std::vector<NVTETensor> te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace;
+  std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
+  auto make_tensor = [&tensor_wrappers](void *dptr, const std::vector<size_t> &shape,
+                                        transformer_engine::DType dtype, void *amax_dptr,
+                                        void *scale_dptr, void *scale_inv_dptr) -> NVTETensor {
+    tensor_wrappers.emplace_back(
+        makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr));
+    return tensor_wrappers.back().data();
+  };
+  for (size_t i = 0; i < A.size(); i++) {
+    if (A[i].data_ptr() == nullptr || B[i].data_ptr() == nullptr) {
+      if (D[i].data_ptr() != nullptr && !accumulate) D[i].zero_();
+      if (bias[i].data_ptr() != nullptr) bias[i].zero_();
+      if (pre_gelu_out[i].data_ptr() != nullptr) pre_gelu_out[i].zero_();
+      continue;
+    }
+    te_A.emplace_back(make_tensor(
+        A[i].data_ptr(), {static_cast<size_t>(A[i].size(0)), static_cast<size_t>(A[i].size(1))},
+        A_type, nullptr, nullptr, getDataPtr(A_scale_inverse, A_offset + i)));
+    te_B.emplace_back(make_tensor(
+        B[i].data_ptr(), {static_cast<size_t>(B[i].size(0)), static_cast<size_t>(B[i].size(1))},
+        B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i)));
+    te_D.emplace_back(make_tensor(
+        D[i].data_ptr(), {static_cast<size_t>(D[i].size(0)), static_cast<size_t>(D[i].size(1))},
+        D_type, getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr));
+    te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast<size_t>(bias[i].size(0))},
+                                     bias_type, nullptr, nullptr, nullptr));
+    const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr
+                                ? std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0))}
+                                : std::vector<size_t>{static_cast<size_t>(pre_gelu_out[i].size(0)),
+                                                      static_cast<size_t>(pre_gelu_out[i].size(1))};
+    te_pre_gelu_out.emplace_back(make_tensor(
+        pre_gelu_out[i].data_ptr(), gelu_shape,
+        GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr));
+  }
+  for (size_t i = 0; i < workspace.size(); i++) {
+    te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte,
+                                          nullptr, nullptr, nullptr));
+  }
+  // For now, we only have multi-stream cublas backend.
+  nvte_multi_stream_cublas_batchgemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
+                                     te_pre_gelu_out.data(), te_A.size(), transa, transb, grad,
+                                     te_workspace.data(), accumulate, use_split_accumulator,
+                                     math_sm_count, at::cuda::getCurrentCUDAStream());
+}
+
+transformer_engine::DType reverse_map_dtype(int64_t dtype) {
+  if (dtype >= 0 && dtype < static_cast<int64_t>(transformer_engine::DType::kNumTypes)) {
+    return static_cast<transformer_engine::DType>(dtype);
+  } else {
+    NVTE_ERROR("Type not supported.");
+  }
+}
+
+std::vector<at::Tensor> te_batchgemm_ts(
+    std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int64_t A_offset, int64_t A_type,
+    int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse, int64_t B_offset,
+    int64_t B_type, int64_t transb, std::vector<at::Tensor> D, int64_t D_offset,
+    at::Tensor D_scale, int64_t D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
+    int64_t bias_type, std::vector<at::Tensor> pre_gelu_out, int64_t grad,
+    std::vector<at::Tensor> workspace, int64_t workspaceSize, int64_t accumulate,
+    int64_t use_split_accumulator) {
+  using namespace transformer_engine;
+  using namespace transformer_engine::pytorch;
+  // cast inputs to types accepted by te_gemm
+  transformer_engine::DType A_type_arg = reverse_map_dtype(A_type);
+  bool transa_arg = static_cast<bool>(transa);
+  transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
+  bool transb_arg = static_cast<bool>(transb);
+  transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
+  transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type);
+  bool grad_arg = static_cast<bool>(grad);
+  size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
+  bool accumulate_arg = static_cast<bool>(accumulate);
+  bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);
+
+  // Set an external SM Margin to all the GEMMs.
+  // This comes in handy when DP is overlapped with GEMMs
+  const int sm_count = transformer_engine::cuda::sm_count();
+  int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
+
+  te_batchgemm(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, B_scale_inverse, B_offset,
+               B_type_arg, transb_arg, D, D_offset, D_scale, D_type_arg, D_amax, bias,
+               bias_type_arg, pre_gelu_out, grad_arg, workspace, workspaceSize_arg,
+               accumulate_arg, use_split_accumulator_arg, num_math_sms);
+
+  return D;
+}
 #endif
transformer_engine/pytorch/csrc/extensions/pybind.cpp

@@ -175,7 +175,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM");
 #ifdef USE_ROCM
-  m.def("te_general_batched_gemm", &te_general_batched_gemm, "Batched GEMM");  /// rocblas
+  m.def("te_batchgemm_ts", &te_batchgemm_ts, "Batched GEMM");  /// rocblas
 #endif
   m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"),
         py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard<py::gil_scoped_release>());
...
transformer_engine/pytorch/module/batched_linear.py
View file @
9d0f1c9b
...
@@ -2,9 +2,11 @@
...
@@ -2,9 +2,11 @@
#
#
# See LICENSE for license information.
# See LICENSE for license information.
"""BatchedLinear API"""
"""Linear API"""
from
typing
import
Union
,
Optional
,
Callable
,
Tuple
,
List
import
os
import
os
import
logging
from
typing
import
Any
,
Callable
,
Dict
,
Optional
,
Tuple
,
Union
,
List
import
torch
import
torch
import
transformer_engine_torch
as
tex
import
transformer_engine_torch
as
tex
...
@@ -16,7 +18,7 @@ from .base import (
...
@@ -16,7 +18,7 @@ from .base import (
_2X_ACC_DGRAD
,
_2X_ACC_DGRAD
,
_2X_ACC_WGRAD
,
_2X_ACC_WGRAD
,
)
)
from
..fp8
import
FP8GlobalStateManager
from
..fp8
import
get_fp8_te_dtype
,
FP8GlobalStateManager
from
..utils
import
(
from
..utils
import
(
divide
,
divide
,
cast_if_needed
,
cast_if_needed
,
...
@@ -32,27 +34,42 @@ from ..distributed import (
...
@@ -32,27 +34,42 @@ from ..distributed import (
in_fp8_activation_recompute_phase
,
in_fp8_activation_recompute_phase
,
)
)
from
..cpp_extensions
import
(
from
..cpp_extensions
import
(
general_
batch
ed_
gemm
,
batchgemm
,
)
)
from
..constants
import
GemmParallelModes
,
dist_group_type
,
TE_DType
from
..constants
import
GemmParallelModes
,
dist_group_type
from
..jit
import
no_torch_dynamo
from
..jit
import
no_torch_dynamo
from
..graph
import
is_graph_capturing
from
..graph
import
is_graph_capturing
from
..tensor.float8_tensor
import
Float8Tensor
from
..float8_tensor
import
Float8Tensor
from
..cpu_offload
import
is_cpu_offload_enabled
# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
from
..tensor.quantized_tensor
import
(
_NVTE_DEBUG
=
int
(
os
.
getenv
(
"NVTE_DEBUG"
,
"0"
))
QuantizedTensor
,
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
Quantizer
,
_NVTE_DEBUG_LEVEL
=
int
(
os
.
getenv
(
"NVTE_DEBUG_LEVEL"
,
"0"
))
prepare_for_saving
,
log_level
=
_NVTE_DEBUG
*
_NVTE_DEBUG_LEVEL
restore_from_saved
,
log_levels
=
{
0
:
logging
.
WARNING
,
1
:
logging
.
INFO
,
2
:
logging
.
DEBUG
}
logging
.
basicConfig
(
format
=
"[%(levelname)-8s | %(name)-19s]: %(message)s"
,
level
=
log_levels
[
log_level
if
log_level
in
[
0
,
1
,
2
]
else
2
],
)
)
__all__
=
[
"BatchedLinear"
]
__all__
=
[
"BatchedLinear"
]
class
_BatchedLinear
(
torch
.
autograd
.
Function
):
"""
"""BatchedLinear semi-top level module
The offset for fp8_meta_index.
_GEMM_INPUT = 0
_GEMM_WEIGHT = num_gemms
_GEMM_OUTPUT = 2 * num_gemms
Must be properly set in BatchedLinear's initialization.
"""
_GEMM_INPUT
=
0
_GEMM_WEIGHT
=
0
_GEMM_OUTPUT
=
0
_GRAD_OUTPUT
=
0
class
_BatchLinear
(
torch
.
autograd
.
Function
):
"""BatchLinear semi-top level module
Calls custom cuda extensions.
Calls custom cuda extensions.
"""
"""
...
@@ -65,205 +82,137 @@ class _BatchedLinear(torch.autograd.Function):
...
@@ -65,205 +82,137 @@ class _BatchedLinear(torch.autograd.Function):
is_first_microbatch
:
Union
[
bool
,
None
],
is_first_microbatch
:
Union
[
bool
,
None
],
fp8
:
bool
,
fp8
:
bool
,
fp8_calibration
:
bool
,
fp8_calibration
:
bool
,
input_quantizers
:
List
[
Quantizer
],
fp8_meta
:
Dict
[
str
,
Any
],
weight_quantizers
:
List
[
Quantizer
],
output_quantizers
:
List
[
Quantizer
],
grad_output_quantizers
:
List
[
Quantizer
],
fuse_wgrad_accumulation
:
bool
,
fuse_wgrad_accumulation
:
bool
,
cpu_offloading
:
bool
,
cpu_offloading
:
bool
,
tp_group
:
Union
[
dist_group_type
,
None
],
tp_size
:
int
,
sequence_parallel
:
bool
,
sequence_parallel
:
bool
,
tensor_parallel
:
bool
,
activation_dtype
:
torch
.
dtype
,
activation_dtype
:
torch
.
dtype
,
parallel_mode
:
Union
[
str
,
None
],
is_grad_enabled
:
bool
,
is_grad_enabled
:
bool
,
module
,
*
weights_and_biases
:
Union
[
Float8Tensor
,
torch
.
Tensor
,
None
],
skip_fp8_weight_update
,
*
weights_and_biases
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
batch_num
=
int
(
os
.
getenv
(
"NVTE_MOE_BATCHCOUNT"
,
"2"
))
batch_num
=
int
(
os
.
getenv
(
"NVTE_MOE_BATCHCOUNT"
,
"2"
))
logger
=
logging
.
getLogger
(
"BatchLinear"
)
# pylint: disable=missing-function-docstring
num_gemms
=
len
(
m_splits
)
num_gemms
=
len
(
m_splits
)
weights
=
weights_and_biases
[:
num_gemms
]
weights
=
weights_and_biases
[:
num_gemms
]
biases
=
weights_and_biases
[
num_gemms
:]
weights_fp8
=
weights_and_biases
[
num_gemms
:
2
*
num_gemms
]
device
=
inp
.
device
biases
=
weights_and_biases
[
2
*
num_gemms
:]
# TODO Support MXFP8 # pylint: disable=fixme
if
fp8
and
FP8GlobalStateManager
.
get_fp8_recipe
().
mxfp8
():
raise
NotImplementedError
(
"BatchedLinear does not yet support MXFP8"
)
# TODO Support Float8 Current Scaling # pylint: disable=fixme
if
fp8
and
FP8GlobalStateManager
.
get_fp8_recipe
().
float8_current_scaling
():
raise
NotImplementedError
(
"BatchedLinear does not yet support Float8 Current Scaling"
)
# TODO Support Float8 Delayed Scaling # pylint: disable=fixme
if
fp8
and
FP8GlobalStateManager
.
get_fp8_recipe
().
delayed
():
raise
NotImplementedError
(
"BatchedLinear does not yet support Float8 Delayed Scaling"
)
# TODO Support Float8 Per Tensor Scaling # pylint: disable=fixme
if
fp8
and
FP8GlobalStateManager
.
get_fp8_recipe
().
float8_per_tensor_scaling
():
raise
NotImplementedError
(
"BatchedLinear does not yet support Float8 Per Tensor Scaling"
)
# Make sure input dimensions are compatible
# Make sure input dimensions are compatible
in_features
=
weights
[
0
].
shape
[
-
1
]
in_features
=
weights
[
0
].
shape
[
-
1
]
assert
inp
.
shape
[
-
1
]
==
in_features
,
"GEMM not possible"
assert
inp
.
shape
[
-
1
]
==
in_features
,
"GEMM not possible"
inputmats
=
torch
.
split
(
inp
.
view
(
-
1
,
in_features
),
m_splits
)
inputmats
=
torch
.
split
(
inp
.
view
(
-
1
,
in_features
),
m_splits
)
if
fp8
:
if
fp8
:
assert
_dim_for_fp8_exec
(
*
inputmats
,
*
weights
)
assert
False
,
"BatchLinear does not support fp8 yet."
# Cast input to expected dtype
# Cast input to expected dtype
inputmats_no_fp8
=
[
cast_if_needed
(
mat
,
activation_dtype
)
for
mat
in
inputmats
]
inputmats_no_fp8
=
[
cast_if_needed
(
mat
,
activation_dtype
)
for
mat
in
inputmats
]
inputmats
=
[]
inputmats
=
[]
inputmats_t
=
[]
global
_GEMM_INPUT
,
_GEMM_WEIGHT
,
_GEMM_OUTPUT
inputmats
=
inputmats_no_fp8
weight_requires_grad
=
weights
[
0
].
requires_grad
logger
.
debug
(
"Running forward in %s"
,
activation_dtype
)
if
input_quantizers
[
0
]
is
not
None
:
for
input_quantizer
in
input_quantizers
:
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
(
is_grad_enabled
and
weight_requires_grad
),
)
columnwise_usage
=
is_grad_enabled
and
inp
.
requires_grad
if
not
columnwise_usage
:
columnwise_usage
=
(
is_fp8_activation_recompute_enabled
()
and
not
in_fp8_activation_recompute_phase
()
)
if
weight_quantizers
[
0
]
is
not
None
:
for
weight_quantizer
in
weight_quantizers
:
weight_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
columnwise_usage
)
if
output_quantizers
[
0
]
is
not
None
:
for
output_quantizer
in
output_quantizers
:
output_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
if
fp8
:
inputmats
=
tex
.
fused_multi_quantize
(
inputmats_no_fp8
,
None
,
input_quantizers
,
TE_DType
[
activation_dtype
]
)
weights_fp8
=
[]
bias_dtype
=
torch
.
bfloat16
if
activation_dtype
==
torch
.
float32
else
activation_dtype
if
not
isinstance
(
weights
[
0
],
QuantizedTensor
):
# FP8 cast to workspace buffer
update_workspace
=
is_first_microbatch
is
None
or
is_first_microbatch
for
i
in
range
(
num_gemms
):
weight_fp8
=
module
.
get_weight_workspace
(
tensor
=
weights
[
i
],
quantizer
=
weight_quantizers
[
i
],
cache_name
=
(
None
if
is_first_microbatch
is
None
else
f
"weight
{
i
}
"
),
update_workspace
=
update_workspace
,
skip_update_flag
=
skip_fp8_weight_update
,
)
weights_fp8
.
append
(
weight_fp8
)
else
:
weights_fp8
=
weights
else
:
inputmats
=
inputmats_no_fp8
bias_dtype
=
activation_dtype
weights_fp8
=
[
cast_if_needed
(
weight
,
activation_dtype
)
for
weight
in
weights
]
biases
=
[
cast_if_needed
(
bias
,
bias_dtype
)
for
bias
in
biases
]
if
use_bias
else
biases
# Cast for native AMP
weights
=
[
cast_if_needed
(
w
,
activation_dtype
)
for
w
in
weights
]
biases
=
(
[
cast_if_needed
(
bias
,
activation_dtype
)
for
bias
in
biases
]
if
use_bias
else
biases
)
assert
weights
[
0
].
size
(
0
)
%
batch_num
==
0
,
"weights[0].size(0) should be batch_num multiply."
assert
weights_fp8
[
0
].
size
(
0
)
%
batch_num
==
0
,
"weights_fp8[0].size(0) should be batch_num multiply."
out
=
torch
.
empty
(
out
=
torch
.
empty
(
[
sum
(
m_splits
),
weights
_fp8
[
0
].
size
(
0
)
//
batch_num
],
[
sum
(
m_splits
),
int
(
weights
[
0
].
size
(
0
)
//
batch_num
)
],
dtype
=
activation_dtype
,
dtype
=
activation_dtype
,
device
=
device
,
device
=
inputmats
[
0
].
device
,
)
)
_
=
batchgemm
(
_
=
general_batched_gemm
(
weights
,
weights_fp8
,
inputmats
,
inputmats
,
[
out
]
,
torch
.
split
(
out
,
m_splits
)
,
activation_dtype
,
activation_dtype
,
get_multi_stream_cublas_batchgemm_workspace
(),
get_multi_stream_cublas_batchgemm_workspace
(),
single_output
=
True
,
m_splits
=
m_splits
,
bias
=
biases
,
bias
=
biases
,
use_bias
=
use_bias
,
use_bias
=
use_bias
,
use_split_accumulator
=
_2X_ACC_FPROP
,
)
)
if
fp8_calibration
:
for
i
in
range
(
num_gemms
):
# amax of input
for
i
in
range
(
num_gemms
):
input_quantizers
[
i
].
calibrate
(
inputmats
[
i
])
for
i
in
range
(
num_gemms
):
weight_quantizers
[
i
].
calibrate
(
weights
[
i
])
if
is_grad_enabled
:
if
is_grad_enabled
:
saved_inputmats
=
[
None
]
*
num_gemms
ctx
.
weights_shape_1
=
weights
[
0
].
shape
[
1
]
saved_inputmats_t
=
[
None
]
*
num_gemms
if
weights
[
0
].
requires_grad
:
tensors_to_save
,
tensor_objects
=
prepare_for_saving
(
*
inputmats
,
*
weights_fp8
,
*
biases
)
saved_inputmats
=
inputmats_no_fp8
ctx
.
save_for_backward
(
*
tensors_to_save
)
ctx
.
tensor_objects
=
tensor_objects
if
cpu_offloading
:
if
fuse_wgrad_accumulation
:
ctx
.
weights_requires_grad
=
weights
[
0
].
requires_grad
for
w
in
weights
:
if
fuse_wgrad_accumulation
and
ctx
.
weights_requires_grad
:
w
.
main_grad
.
weight_offloading
=
True
ctx
.
main_grads
=
[
weights
[
i
].
main_grad
for
i
in
range
(
num_gemms
)]
for
w
in
weights
:
else
:
w
.
weight_offloading
=
True
ctx
.
main_grads
=
[
None
]
*
num_gemms
for
t
in
saved_inputmats
:
ctx
.
device
=
device
if
t
is
not
None
:
ctx
.
grad_output_quantizers
=
grad_output_quantizers
t
.
activation_offloading
=
True
ctx
.
save_for_backward
(
None
,
*
saved_inputmats
,
*
saved_inputmats_t
,
*
weights
,
*
weights_fp8
,
*
[
w
.
main_grad
if
cpu_offloading
and
fuse_wgrad_accumulation
else
None
for
w
in
weights
],
)
ctx
.
m_splits
=
m_splits
ctx
.
m_splits
=
m_splits
ctx
.
num_gemms
=
num_gemms
ctx
.
num_gemms
=
num_gemms
ctx
.
activation_dtype
=
activation_dtype
ctx
.
activation_dtype
=
activation_dtype
ctx
.
fp8
=
fp8
ctx
.
fp8
=
fp8
ctx
.
fp8_meta
=
fp8_meta
ctx
.
fuse_wgrad_accumulation
=
fuse_wgrad_accumulation
ctx
.
fuse_wgrad_accumulation
=
fuse_wgrad_accumulation
ctx
.
cpu_offloading
=
cpu_offloading
ctx
.
cpu_offloading
=
cpu_offloading
ctx
.
is_first_microbatch
=
is_first_microbatch
ctx
.
is_first_microbatch
=
is_first_microbatch
ctx
.
use_bias
=
use_bias
ctx
.
use_bias
=
use_bias
ctx
.
sequence_parallel
=
sequence_parallel
ctx
.
sequence_parallel
=
sequence_parallel
ctx
.
tensor_parallel
=
tensor_parallel
ctx
.
inp_shape
=
inp
.
shape
ctx
.
inp_shape
=
inp
.
shape
ctx
.
parallel_mode
=
parallel_mode
ctx
.
tp_group
=
tp_group
ctx
.
tp_size
=
tp_size
ctx
.
requires_dgrad
=
inp
.
requires_grad
ctx
.
requires_dgrad
=
inp
.
requires_grad
ctx
.
reduce_and_update_bwd_fp8_tensors
=
False
ctx
.
reduce_and_update_bwd_fp8_tensors
=
False
if
ctx
.
fp8
and
requires_grad
(
inp
,
weights
[
0
],
biases
[
0
]):
ctx
.
reduce_and_update_bwd_fp8_tensors
=
(
ctx
.
reduce_and_update_bwd_fp8_tensors
or
FP8GlobalStateManager
.
is_first_fp8_module
()
)
# [*, in_features] -> [*, out_features] except first dimension changes for SP
# [*, in_features] -> [*, out_features] except first dimension changes for SP
return
out
.
view
(
-
1
,
*
inp
.
shape
[
1
:
-
1
],
out
.
shape
[
-
1
])
return
out
.
view
(
-
1
,
*
inp
.
shape
[
1
:
-
1
],
out
.
shape
[
-
1
])
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
grad_output
:
torch
.
Tensor
)
->
Tuple
[
Union
[
torch
.
Tensor
,
None
],
...]:
def
backward
(
ctx
,
grad_output
:
torch
.
Tensor
)
->
Tuple
[
Union
[
torch
.
Tensor
,
None
],
...]:
# pylint: disable=missing-function-docstring
logger
=
logging
.
getLogger
(
"BatchLinear"
)
with
torch
.
cuda
.
nvtx
.
range
(
"_BatchedLinear_backward"
):
saved_tensors
=
restore_from_saved
(
ctx
.
tensor_objects
,
ctx
.
saved_tensors
)
with
torch
.
cuda
.
nvtx
.
range
(
"_BatchLinear_backward"
):
N
=
ctx
.
num_gemms
(
inputmats
=
saved_tensors
[:
N
]
fwd_scale_inverses
,
weights
=
saved_tensors
[
N
:
2
*
N
]
*
saved_tensors
,
biases
=
saved_tensors
[
2
*
N
:
3
*
N
]
)
=
ctx
.
saved_tensors
main_grads
=
ctx
.
main_grads
inputmats
=
saved_tensors
[:
ctx
.
num_gemms
]
inputmats_t
=
saved_tensors
[
ctx
.
num_gemms
:
2
*
ctx
.
num_gemms
]
if
ctx
.
cpu_offloading
and
ctx
.
fuse_wgrad_accumulation
:
# TOSO
weights
=
saved_tensors
[
2
*
ctx
.
num_gemms
:
3
*
ctx
.
num_gemms
]
weights_fp8
=
saved_tensors
[
3
*
ctx
.
num_gemms
:
4
*
ctx
.
num_gemms
]
main_grads
=
saved_tensors
[
4
*
ctx
.
num_gemms
:]
if
ctx
.
cpu_offloading
and
ctx
.
fuse_wgrad_accumulation
:
for
i
in
ctx
.
num_gemms
:
for
i
in
ctx
.
num_gemms
:
w
=
torch
.
nn
.
Parameter
(
weights
[
i
],
weights
[
i
].
requires_grad
)
w
=
torch
.
nn
.
Parameter
(
weights
[
i
],
False
)
w
.
main_grad
=
main_grads
[
i
]
w
.
main_grad
=
main_grads
[
i
]
weights
[
i
]
=
w
weights
[
i
]
=
w
global
_GEMM_INPUT
,
_GEMM_WEIGHT
,
_GRAD_OUTPUT
# preprocess grad_output
grad_output
=
grad_output
.
contiguous
()
grad_output
=
grad_output
.
contiguous
()
grad_output_mats
=
torch
.
split
(
grad_output_mats
=
torch
.
split
(
grad_output
.
view
(
-
1
,
grad_output
.
shape
[
-
1
]),
ctx
.
m_splits
grad_output
.
view
(
-
1
,
grad_output
.
shape
[
-
1
]),
ctx
.
m_splits
)
)
grad_output
=
[
None
]
*
ctx
.
num_gemms
grad_output_c
=
[
None
]
*
ctx
.
num_gemms
grad_output_t
=
[
None
]
*
ctx
.
num_gemms
grad_biases
=
[
None
]
*
ctx
.
num_gemms
grad_biases
=
[
None
]
*
ctx
.
num_gemms
if
ctx
.
fp8
:
if
ctx
.
use_bias
:
for
i
in
range
(
ctx
.
num_gemms
):
grad_biases
[
i
],
grad_output
[
i
]
=
tex
.
bgrad_quantize
(
grad_output_mats
[
i
],
ctx
.
grad_output_quantizers
[
i
]
)
else
:
grad_output
=
tex
.
fused_multi_quantize
(
grad_output_mats
,
None
,
ctx
.
grad_output_quantizers
,
TE_DType
[
ctx
.
activation_dtype
],
)
else
:
grad_output
=
grad_output_mats
if
ctx
.
is_first_microbatch
is
not
None
:
if
ctx
.
is_first_microbatch
is
not
None
:
accumulate_wgrad_into_param_main_grad
=
(
accumulate_wgrad_into_param_main_grad
=
(
...
@@ -273,114 +222,105 @@ class _BatchedLinear(torch.autograd.Function):
...
@@ -273,114 +222,105 @@ class _BatchedLinear(torch.autograd.Function):
accumulate_wgrad_into_param_main_grad
=
ctx
.
fuse_wgrad_accumulation
accumulate_wgrad_into_param_main_grad
=
ctx
.
fuse_wgrad_accumulation
if
ctx
.
requires_dgrad
:
if
ctx
.
requires_dgrad
:
logger
.
debug
(
"Running backward in %s"
,
ctx
.
activation_dtype
)
dgrad
=
torch
.
empty
(
dgrad
=
torch
.
empty
(
(
sum
(
ctx
.
m_splits
),
ctx
.
weights
_shape_1
),
(
sum
(
ctx
.
m_splits
),
int
(
weights
[
0
].
size
(
1
))
),
dtype
=
ctx
.
activation_dtype
,
dtype
=
ctx
.
activation_dtype
,
device
=
ctx
.
device
,
device
=
grad_output
.
device
,
)
)
batchgemm
(
general_batched_gemm
(
weights
,
weights
,
grad_output
,
grad_output
_mats
,
[
dgrad
]
,
torch
.
split
(
dgrad
,
ctx
.
m_splits
)
,
ctx
.
activation_dtype
,
ctx
.
activation_dtype
,
get_multi_stream_cublas_batchgemm_workspace
(),
get_multi_stream_cublas_batchgemm_workspace
(),
single_output
=
True
,
layout
=
"NN"
,
layout
=
"NN"
,
m_splits
=
ctx
.
m_splits
,
grad
=
True
,
grad
=
True
,
use_split_accumulator
=
_2X_ACC_DGRAD
,
)
)
if
ctx
.
weights
_
requires_grad
:
if
weights
[
0
].
requires_grad
:
if
ctx
.
fuse_wgrad_accumulation
:
if
ctx
.
fuse_wgrad_accumulation
:
wgrad_list
=
main_grad
s
wgrad_list
=
[
w
.
main_grad
for
w
in
weights
]
else
:
else
:
wgrad_list
=
[
wgrad_list
=
[
torch
.
empty
(
w
.
size
(),
dtype
=
ctx
.
activation_dtype
,
device
=
ctx
.
device
)
torch
.
empty
(
w
.
size
(),
dtype
=
ctx
.
activation_dtype
,
device
=
w
.
device
)
for
w
in
weights
for
w
in
weights
]
]
# WGRAD
# WGRAD
_
,
grad_biases
_
,
_
=
general_
batch
ed_
gemm
(
_
,
grad_biases
,
_
=
batchgemm
(
inputmats
,
inputmats
,
grad_output
,
grad_output
_mats
,
wgrad_list
,
wgrad_list
,
ctx
.
activation_dtype
,
ctx
.
activation_dtype
,
get_multi_stream_cublas_batchgemm_workspace
(),
get_multi_stream_cublas_batchgemm_workspace
(),
layout
=
"NT"
,
layout
=
"NT"
,
grad
=
True
,
grad
=
True
,
m_splits
=
ctx
.
m_splits
,
use_bias
=
ctx
.
use_bias
,
use_bias
=
ctx
.
use_bias
if
grad_biases
[
0
]
is
None
else
None
,
bias
=
biases
,
use_split_accumulator
=
_2X_ACC_WGRAD
,
accumulate
=
accumulate_wgrad_into_param_main_grad
,
accumulate
=
accumulate_wgrad_into_param_main_grad
,
)
)
for
i
in
range
(
ctx
.
num_gemms
):
if
grad_biases
[
i
]
is
None
:
grad_biases
[
i
]
=
grad_biases_
[
i
]
del
grad_biases_
# Deallocate input tensor
# Deallocate input tensor
clear_tensor_data
(
*
inputmats
)
clear_tensor_data
(
*
inputmats
)
clear_tensor_data
(
*
inputmats_t
)
def
handle_custom_ddp_from_mcore
(
w
,
wgrad
):
if
not
ctx
.
use_bias
:
if
ctx
.
weights_requires_grad
:
grad_biases
=
[
None
]
*
ctx
.
num_gemms
if
ctx
.
fuse_wgrad_accumulation
and
hasattr
(
w
,
"grad_added_to_main_grad"
):
w
.
grad_added_to_main_grad
=
True
if
getattr
(
w
,
"zero_out_wgrad"
,
False
):
wgrad
=
torch
.
zeros
(
w
.
main_grad
.
shape
,
dtype
=
w
.
dtype
,
device
=
torch
.
cuda
.
current_device
(),
requires_grad
=
False
,
)
else
:
wgrad
=
torch
.
empty
(
w
.
main_grad
.
shape
,
dtype
=
w
.
dtype
,
device
=
torch
.
cuda
.
current_device
(),
requires_grad
=
False
,
)
elif
ctx
.
fuse_wgrad_accumulation
:
wgrad
=
None
else
:
wgrad
=
None
return
wgrad
wgrad_list
=
[
def
handle_custom_ddp_from_mcore
(
w
,
wgrad
):
handle_custom_ddp_from_mcore
(
w
,
wgrad
)
for
w
,
wgrad
in
zip
(
weights
,
wgrad_list
)
if
w
.
requires_grad
:
]
if
ctx
.
fuse_wgrad_accumulation
and
hasattr
(
w
,
"grad_added_to_main_grad"
):
w
.
grad_added_to_main_grad
=
True
if
getattr
(
w
,
"zero_out_wgrad"
,
False
):
wgrad
=
torch
.
zeros
(
w
.
main_grad
.
shape
,
dtype
=
w
.
dtype
,
device
=
torch
.
cuda
.
current_device
(),
requires_grad
=
False
,
)
else
:
wgrad
=
torch
.
empty
(
w
.
main_grad
.
shape
,
dtype
=
w
.
dtype
,
device
=
torch
.
cuda
.
current_device
(),
requires_grad
=
False
,
)
elif
ctx
.
fuse_wgrad_accumulation
:
wgrad
=
None
else
:
else
:
wgrad_list
=
[
None
]
*
ctx
.
num_gemms
wgrad
=
None
return
wgrad
if
not
ctx
.
use_bias
:
wgrad_list
=
[
grad_biases
=
[
None
]
*
ctx
.
num_gemms
handle_custom_ddp_from_mcore
(
w
,
wgrad
)
for
w
,
wgrad
in
zip
(
weights
,
wgrad_list
)
]
if
ctx
.
reduce_and_update_bwd_fp8_tensors
and
not
is_graph_capturing
():
if
ctx
.
reduce_and_update_bwd_fp8_tensors
and
not
is_graph_capturing
():
FP8GlobalStateManager
.
reduce_and_update_fp8_tensors
(
forward
=
False
)
FP8GlobalStateManager
.
reduce_and_update_fp8_tensors
(
forward
=
False
)
return
(
return
(
dgrad
.
view
(
ctx
.
inp_shape
)
if
ctx
.
requires_dgrad
else
None
,
dgrad
.
view
(
ctx
.
inp_shape
)
if
ctx
.
requires_dgrad
else
None
,
None
,
None
,
# m_splits
None
,
None
,
# use_bias
None
,
None
,
# is_first_microbatch
None
,
None
,
# fp8
None
,
None
,
# fp8_calibration
None
,
None
,
# fp8_meta
None
,
None
,
# fuse_wgrad_accumulation
None
,
None
,
# cpu_offloading
None
,
None
,
# tp_group
None
,
None
,
# tp_size
None
,
None
,
# sequence_parallel
None
,
None
,
# tensor_parallel
None
,
None
,
# activation_dtype
None
,
None
,
# parallel_mode
None
,
# is_grad_enabled
None
,
# is_grad_enabled
None
,
# is_grad_enabled
*
wgrad_list
,
*
wgrad_list
,
*
([
None
]
*
ctx
.
num_gemms
),
# weights_fp8
*
grad_biases
,
*
grad_biases
,
)
)
class
BatchedLinear
(
TransformerEngineBaseModule
):
class
BatchedLinear
(
TransformerEngineBaseModule
):
"""Applies linear transformations to the incoming data list
"""Applies linear transformations to the incoming data list
:math:`y_i = x_iA_i^T + b_i` in a batched way.
:math:`y_i = x_iA_i^T + b_i` in a batched way.
...
@@ -399,14 +339,31 @@ class BatchedLinear(TransformerEngineBaseModule):
...
@@ -399,14 +339,31 @@ class BatchedLinear(TransformerEngineBaseModule):
                 used for initializing weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    get_rng_state_tracker : Callable, default = `None`
                 used to get the random number generator state tracker for initializing weights.
    rng_tracker_name : str, default = `None`
                 the param passed to get_rng_state_tracker to get the specific rng tracker.
    device : Union[torch.device, str], default = "cuda"
                 The device on which the parameters of the model will be allocated. It is the user's
                 responsibility to ensure all parameters are moved to the GPU before running the
                 forward pass.

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                 if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
                 tensor parallel process group.
    tp_size : int, default = 1
                 used as TP (tensor parallel) world size when TP groups are not formed during
                 initialization. In this case, users must call the
                 `set_tensor_parallel_group(tp_group)` method on the initialized module before the
                 forward pass to supply the tensor parallel group needed for tensor and sequence
                 parallel collectives.
    parallel_mode : {None, 'Column', 'Row'}, default = `None`
                 used to decide whether this BatchedLinear layer is Column Parallel Linear or Row
                 Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                 When set to `None`, no communication is performed.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
...
@@ -426,7 +383,6 @@ class BatchedLinear(TransformerEngineBaseModule):
                 would not fit in GPU memory.
    """
    def __init__(
        self,
        num_gemms: int,
...
@@ -462,15 +418,15 @@ class BatchedLinear(TransformerEngineBaseModule):
        self.apply_bias = bias and not return_bias
        self.ub_overlap_rs = ub_overlap_rs
        self.ub_overlap_ag = ub_overlap_ag
-        if ub_overlap_rs or ub_overlap_ag:
-            assert ub_name is not None, "Userbuffer name [string] is not set."
        self.ub_name = ub_name
+        assert (
+            not ub_overlap_rs and not ub_overlap_ag
+        ), "BatchedLinear doesn't support Userbuffer overlap."
        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name
-        self._offsets = {"input": 0, "weight": self.num_gemms, "output": 2 * self.num_gemms, "grad_output": 0}
+        global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
+        _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT = 0, self.num_gemms, 2 * self.num_gemms
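As a reading aid, the flat FP8-metadata index layout implied by the constants above can be spelled out for a concrete size (illustrative only; the role names follow the constant names in the diff):

# For num_gemms = 3, per-GEMM FP8 meta indices are laid out contiguously:
#   input  i -> _GEMM_INPUT  + i  = 0, 1, 2
#   weight i -> _GEMM_WEIGHT + i  = 3, 4, 5   (used as fp8_meta_index for weight{i})
#   output i -> _GEMM_OUTPUT + i  = 6, 7, 8
num_gemms = 3
_GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT = 0, num_gemms, 2 * num_gemms
assert [_GEMM_WEIGHT + i for i in range(num_gemms)] == [3, 4, 5]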
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
...
@@ -492,7 +448,7 @@ class BatchedLinear(TransformerEngineBaseModule):
        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

        # In batchgemm, we use batch=batch_num to launch blas batchgemm
-        for i in range(self.num_gemms):
+        for i in range(int(self.num_gemms)):
            # Construct weight parameter
            self.register_parameter(
                f"weight{i}",
...
@@ -506,7 +462,7 @@ class BatchedLinear(TransformerEngineBaseModule):
                ),
                init_fn=init_method,
                get_rng_state_tracker=get_rng_state_tracker,
-                fp8_meta_index=self._offsets["weight"] + i,
+                fp8_meta_index=_GEMM_WEIGHT + i,
            )

            # Construct bias parameters if needed
...
@@ -515,7 +471,7 @@ class BatchedLinear(TransformerEngineBaseModule):
                    f"bias{i}",
                    torch.nn.Parameter(
                        torch.empty(
-                            self.out_features,
+                            self.out_features * self.batch_num,
                            device=device,
                            dtype=params_dtype,
                        ),
...
@@ -525,11 +481,15 @@ class BatchedLinear(TransformerEngineBaseModule):
                else:
                    bias = torch.Tensor().to(dtype=params_dtype, device=device)
                setattr(self, f"bias{i}", bias)

        if self.primary_weights_in_fp8:
            self.init_fp8_metadata(num_gemms=self.num_gemms)

-        self.reset_parameters(defer_init=device == "meta")
+        self.reset_parameters(defer_init=(device == "meta"))

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
...
@@ -543,7 +503,7 @@ class BatchedLinear(TransformerEngineBaseModule):
        if not defer_init:
            # Set parallelism attributes for linear weights
-            for i in range(self.num_gemms):
+            for i in range(int(self.num_gemms)):
                set_tensor_model_parallel_attributes(
                    tensor=getattr(self, f"weight{i}"),
                    is_parallel=True,
...
@@ -553,15 +513,15 @@ class BatchedLinear(TransformerEngineBaseModule):
            # Set parallelism attributes for linear biases
            if self.use_bias:
-                for bias in self.bias_names:
+                for i in range(self.num_gemms):
                    if self.parallel_mode == "row":
                        setattr(
-                            getattr(self, bias),
+                            getattr(self, f"bias{i}"),
                            "sequence_parallel",
                            self.sequence_parallel,
                        )
                    elif self.parallel_mode == "column":
-                        set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1)
+                        set_tensor_model_parallel_attributes(getattr(self, f"bias{i}"), True, 0, 1)
    @no_torch_dynamo()
    def forward(
...
@@ -593,57 +553,33 @@ class BatchedLinear(TransformerEngineBaseModule):
            first microbatch (since it is the first gradient being
            produced)
        """
        assert not isinstance(
            inp, Float8Tensor
        ), "BatchedLinear doesn't support input tensor in FP8."

        m_splits_batch_gemm = [x * self.batch_num for x in m_splits[0 : int(self.num_gemms)]]
        assert len(m_splits_batch_gemm) == self.num_gemms, "Number of splits should match number of GEMMs."
        skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
        if skip_fp8_weight_update is not None:
            is_first_microbatch = False
-        with self.prepare_forward(inp=inp, num_gemms=self.num_gemms) as inp:
+        with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
            weight_tensors = [getattr(self, f"weight{i}") for i in range(int(self.num_gemms))]
            bias_tensors = [getattr(self, f"bias{i}") for i in range(int(self.num_gemms))]
            if not self.fp8:
                weight_tensors = [
-                    w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors
+                    w.from_float8() if isinstance(w, Float8Tensor) else w for w in weight_tensors
                ]

-            input_quantizers, weight_quantizers, output_quantizers = (
-                [None] * self.num_gemms,
-                [None] * self.num_gemms,
-                [None] * self.num_gemms,
-            )
-            grad_output_quantizers, _ = [None] * self.num_gemms, [None] * self.num_gemms
+            weight_tensors_fp8 = [None] * int(self.num_gemms)
+
+            from ..cpu_offload import CPUOffloadEnabled
-            if self.fp8:
-                input_quantizers = [
-                    self.quantizers["scaling_fwd"][self._offsets["input"] + i]
-                    for i in range(self.num_gemms)
-                ]
-                for i in range(self.num_gemms):
-                    input_quantizers[i].internal = True
-                weight_quantizers = [
-                    self.quantizers["scaling_fwd"][self._offsets["weight"] + i]
-                    for i in range(self.num_gemms)
-                ]
-                for i in range(self.num_gemms):
-                    weight_quantizers[i].internal = True
-                if torch.is_grad_enabled():
-                    grad_output_quantizers = [
-                        self.quantizers["scaling_bwd"][self._offsets["input"] + i]
-                        for i in range(self.num_gemms)
-                    ]
-                    for i in range(self.num_gemms):
-                        grad_output_quantizers[i].internal = True
            if torch.is_grad_enabled():
-                linear_fn = _BatchedLinear.apply
+                linear_fn = _BatchLinear.apply
                args = []
            else:
-                linear_fn = _BatchedLinear.forward
+                linear_fn = _BatchLinear.forward
                args = [None]
            args += (
                inp,
...
@@ -652,22 +588,22 @@ class BatchedLinear(TransformerEngineBaseModule):
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
-                input_quantizers,
-                weight_quantizers,
-                output_quantizers,
-                grad_output_quantizers,
+                self.fp8_meta,
                self.fuse_wgrad_accumulation,
-                is_cpu_offload_enabled(),
+                CPUOffloadEnabled,
+                self.tp_group,
+                self.tp_size,
                self.sequence_parallel,
+                self.tp_size > 1,
                self.activation_dtype,
+                self.parallel_mode,
                torch.is_grad_enabled(),
-                self,
-                skip_fp8_weight_update,
                *weight_tensors,
+                *weight_tensors_fp8,
                *bias_tensors,
            )
            out = linear_fn(*args)

        if self.gemm_bias_unfused_add:
            out_shape = out.shape
            out = torch.cat(
...
@@ -678,7 +614,6 @@ class BatchedLinear(TransformerEngineBaseModule):
                    )
                ]
            ).view(out_shape)
        if self.return_bias:
            return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
        return out