OpenDAS / TransformerEngine / Commits

Commit 7a923605, authored Aug 21, 2025 by yuguo

[DCU] tensorwise int8 train opt

Parent: 686e93cd

Showing 9 changed files with 1277 additions and 272 deletions (+1277, -272)
tests/pytorch/test_int8_channelwise_gemm_exact.py            +271 -74
transformer_engine/common/gemm/cublaslt_gemm.cu              +76  -1
transformer_engine/common/gemm/rocm_gemm.cu                  +458 -3
transformer_engine/common/include/transformer_engine/gemm.h  +5   -0
transformer_engine/common/transformer_engine.cpp             +1   -1
transformer_engine/pytorch/cpp_extensions/gemm.py            +168 -46
transformer_engine/pytorch/csrc/extensions.h                 +16  -8
transformer_engine/pytorch/csrc/extensions/gemm.cpp          +268 -132
transformer_engine/pytorch/csrc/extensions/pybind.cpp        +14  -7
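The commit routes tensor-wise (one scale per whole tensor) INT8 GEMMs through hipBLASLt on DCU instead of dequantizing per channel on the Python side. As a rough orientation only (none of the names below appear in this commit), the numerics the tensorwise path is built on look like this minimal sketch:

import torch

# Illustrative sketch (not part of this commit): tensorwise INT8 GEMM numerics.
# Each operand gets a single scale, the matmul accumulates in integers,
# and the result is dequantized once with the product of the two scales.
def quantize_tensorwise(x):
    scale = x.abs().max().clamp(min=1e-8) / 127.0  # one scale for the whole tensor
    q = torch.clamp((x / scale).round(), -127, 127).to(torch.int8)
    return q, scale

x = torch.randn(64, 32)
w = torch.randn(48, 32)
qx, sx = quantize_tensorwise(x)
qw, sw = quantize_tensorwise(w)
acc = qx.to(torch.int64) @ qw.to(torch.int64).t()           # integer accumulation (INT32 on the device path)
y = (acc.to(torch.float32) * (sx * sw)).to(torch.bfloat16)  # tensorwise dequantize
print((y.float() - x @ w.t()).abs().max())                  # remaining error is quantization error only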
tests/pytorch/test_int8_channelwise_gemm_exact.py (view file @ 7a923605)
@@ -32,11 +32,56 @@ import os
int8_simulation_fp8_tensorwise = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE", "0")))
tensorwise_int8_check = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE_CHECK", "0")))
def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
    """Estimated numerical error for a datatype

    Based on tolerances for torch.testing.assert_close.
    """
    if dtype == torch.float32:
        return dict(rtol=1.3e-6, atol=1e-5)
    if dtype == torch.float16:
        return dict(rtol=1e-3, atol=1e-5)
    if dtype == torch.bfloat16:
        return dict(rtol=1.6e-2, atol=1e-5)
    raise ValueError(f"Unsupported dtype ({dtype})")
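A usage note for the helper above (assumed, not shown in this diff): the returned tolerances are meant to be splatted into torch.testing.assert_close, for example:

a = torch.randn(8, 8, dtype=torch.bfloat16)
torch.testing.assert_close(a, a.clone(), **dtype_tols(torch.bfloat16))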
def assert_allclose(
    l1: List[torch.Tensor],
    l2: List[torch.Tensor],
    atol: float = None,
    rtol: float = None,
) -> bool:
    """Ensures two lists are equal."""
    assert len(l1) == len(l2), "Unequal number of outputs."
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        tols = dtype_tols(t2.dtype)
        if rtol is not None:
            tols["rtol"] = rtol
        if atol is not None:
            tols["atol"] = atol
        result = torch.allclose(t1, t2, **tols)
        if not result:
            diff = torch.abs(t1 - t2)
            tol = tols["atol"] + (tols["rtol"] * torch.abs(t2))
            exceed_mask = diff > tol
            if exceed_mask.any():
                indices = torch.nonzero(exceed_mask, as_tuple=True)
                max_diff = diff[exceed_mask].max()
                max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
                max_location = [idx[max_idx].item() for idx in indices]
                msg = (
                    f"Outputs not close enough in tensor at idx={i}. "
                    f"Maximum difference at location {max_location} "
                    f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
                    f"(diff {max_diff.item()})."
                )
                raise AssertionError(msg)
# TN
m = 4096
k = 4096
n = 4096
seed = 0
n = 6144
seed = 4096
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
device = "cuda"
@@ -235,40 +280,42 @@ transb = False
x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
output = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)

for i in range(20):
    bf16_out = torch.matmul(x_bf16, w_bf16.t())
print("bf16_out: ", bf16_out)

torch.cuda.synchronize()
start = time.time()
for i in range(20):
    bf16_out = torch.matmul(x_bf16, w_bf16.t())
torch.cuda.synchronize()
end = time.time()

# x_int8, x_scales = per_token_quant_int8(x_bf16)
# w_int8, w_scales = per_token_quant_int8(w_bf16)
# print("x_int8: ", x_int8)
# print("w_int8: ", w_int8)

# Cast to FP8 and back
x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
# print("x_fp8: ", x_fp8._data.view(dtype=torch.float8_e4m3fn))
# print("w_fp8: ", w_fp8._data.view(dtype=torch.float8_e4m3fn))

if int8_simulation_fp8_tensorwise:
    x_int8, x_scales = x_fp8._data.view(dtype=torch.int8), x_fp8._scale_inv
    w_int8, w_scales = w_fp8._data.view(dtype=torch.int8), w_fp8._scale_inv
else:
    x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
    w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
# print("x_int8: ", x_int8)
# print("w_int8: ", w_int8)

# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
y_int32 = tex.generic_gemm(
    w_int8,
    transa,
    x_int8,
@@ -285,19 +332,42 @@ y_int32 = tex.generic_gemm(
    workspace.shape[0],
    accumulate,
    use_split_accumulator,
)[0]
# y_int32 = torch._int_mm(x_int8, w_int8.t())
# print("y_int32: ", y_int32)

if int8_simulation_fp8_tensorwise:
    tensorwise_dequantize(x_scales, w_scales, y_int32, output)
else:
    output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
# print("out_scales.shape: ", out_scales.shape)
# print("out_scales: ", out_scales)
print("bf16_out: ", bf16_out)
print("output: ", output)

if tensorwise_int8_check:
    lt_output = tex.generic_gemm(
        w_fp8, transa, x_fp8, transb, out, out_quantizer, TE_DType[torch.bfloat16],
        bias, bias_dtype, use_gelu, aux_tensor, use_grad,
        workspace, workspace.shape[0], accumulate, True,
    )[0]
    print("lt_output: ", lt_output)
    assert_allclose([output], [lt_output])

# print("out_scales.shape: ", out_scales.shape)
# print("out_scales: ", out_scales)
# torch.cuda.synchronize()
# start = time.time()
@@ -339,6 +409,7 @@ w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
dx = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
bf16_dx = torch.matmul(dy_bf16, w_bf16)
print("bf16_dx: ", bf16_dx)

torch.cuda.synchronize()
start = time.time()
@@ -397,11 +468,32 @@ if int8_simulation_fp8_tensorwise:
else:
    dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
    # dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
print("dx: ", dx)

if tensorwise_int8_check:
    lt_dx = tex.generic_gemm(
        w_fp8, transa, dy_fp8, transb, out, out_quantizer, TE_DType[torch.bfloat16],
        bias, bias_dtype, use_gelu, aux_tensor, use_grad,
        workspace, workspace.shape[0], accumulate, True,
    )[0]
    print("lt_dx: ", lt_dx)
    assert_allclose([dx], [lt_dx])

# print("dx_scales.shape: ", dx_scales.shape)
# print("dx_scales: ", dx_scales)
print("bf16_dx: ", bf16_dx)
print("dx: ", dx)
# torch.cuda.synchronize()
# start = time.time()
@@ -447,11 +539,9 @@ transb = True
dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
dw = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

bf16_dw = torch.matmul(dy_bf16.t(), x_bf16)
print("bf16_dw: ", bf16_dw)

torch.cuda.synchronize()
start = time.time()
@@ -504,9 +594,30 @@ if int8_simulation_fp8_tensorwise:
else:
    dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
    # dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
print("bf16_dw: ", bf16_dw)
print("dw: ", dw)

if tensorwise_int8_check:
    lt_dw = tex.generic_gemm(
        x_fp8, transa, dy_fp8, transb, out, out_quantizer, TE_DType[torch.bfloat16],
        bias, bias_dtype, use_gelu, aux_tensor, use_grad,
        workspace, workspace.shape[0], accumulate, True,
    )[0]
    print("lt_dw: ", lt_dw)
    assert_allclose([dw], [lt_dw])

# torch.cuda.synchronize()
# start = time.time()
# for i in range(20):
@@ -548,9 +659,9 @@ print("dw: ", dw)
# batch gemm wgrad
m = 32
k = 32
n = 32
m = 1024
k = 1024
n = 1024
b = 4
transa = False
@@ -558,9 +669,6 @@ transb = True
dy_int8 = (torch.randn((b, m, n), device=device)).to(dtype=torch.int8)
x_int8 = (torch.randn((b, m, k), device=device)).to(dtype=torch.int8)
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

int32_dw_list = []
for i in range(b):
@@ -597,3 +705,92 @@ te_dw = tex.generic_batchgemm(
# print("te_dw.shape: ", te_dw.view(b, -1, te_dw.size(-1)).shape)
# print("te_dw: ", te_dw.view(b, -1, te_dw.size(-1)))
torch
.
testing
.
assert_close
(
te_dw
.
view
(
b
,
-
1
,
te_dw
.
size
(
-
1
)),
batched_int32_dw
,
atol
=
0
,
rtol
=
0
)
# NT
b
=
4
transa
=
False
transb
=
True
dy_bf16
=
[(
torch
.
randn
((
m
,
n
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
for
i
in
range
(
b
)]
x_bf16
=
[(
torch
.
randn
((
m
,
k
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
for
i
in
range
(
b
)]
dw_ref
=
[(
torch
.
randn
((
n
,
k
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
for
i
in
range
(
b
)]
dw
=
[(
torch
.
randn
((
n
,
k
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
for
i
in
range
(
b
)]
# Cast to FP8 and back
dy_fp8
=
[
to_float8_CS
(
dy_bf16
[
i
],
fp8_dtype
=
tex
.
DType
.
kFloat8E5M2
)
for
i
in
range
(
b
)]
x_fp8
=
[
to_float8_CS
(
x_bf16
[
i
],
fp8_dtype
=
tex
.
DType
.
kFloat8E5M2
)
for
i
in
range
(
b
)]
if
int8_simulation_fp8_tensorwise
:
dy_int8
,
dy_scales
=
[
dy_fp8
[
i
].
_data
.
view
(
dtype
=
torch
.
int8
)
for
i
in
range
(
b
)],
[
dy_fp8
[
i
].
_scale_inv
for
i
in
range
(
b
)]
x_int8
,
x_scales
=
[
x_fp8
[
i
].
_data
.
view
(
dtype
=
torch
.
int8
)
for
i
in
range
(
b
)],
[
x_fp8
[
i
].
_scale_inv
for
i
in
range
(
b
)]
else
:
dy_int8
,
dy_scales
=
[],
[]
x_int8
,
x_scales
=
[],
[]
assert
False
for
i
in
range
(
b
):
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
dw_int32
=
tex
.
generic_gemm
(
x_int8
[
i
],
transa
,
dy_int8
[
i
],
transb
,
None
,
None
,
TE_DType
[
out_dtype
],
bias
,
bias_dtype
,
use_gelu
,
aux_tensor
,
use_grad
,
workspace
,
workspace
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
)[
0
]
if
int8_simulation_fp8_tensorwise
:
tensorwise_dequantize
(
dy_scales
[
i
],
x_scales
[
i
],
dw_int32
,
dw_ref
[
i
])
else
:
assert
False
dw_ref_tensor
=
torch
.
stack
(
dw_ref
).
contiguous
()
# print("dw_ref_tensor: ", dw_ref_tensor)
torch
.
cuda
.
synchronize
()
dy_int8_tensor
=
torch
.
stack
(
dy_int8
).
contiguous
()
dy_scales_tensor
=
torch
.
stack
(
dy_scales
).
contiguous
()
x_int8_tensor
=
torch
.
stack
(
x_int8
).
contiguous
()
x_scales_tensor
=
torch
.
stack
(
x_scales
).
contiguous
()
dw_tensor
=
torch
.
stack
(
dw
).
contiguous
()
out_dtype
=
torch
.
bfloat16
dw_tensor
=
tex
.
tensorwise_int8_batchgemm
(
x_int8_tensor
.
view
(
-
1
,
x_int8
.
size
(
-
1
)),
transa
,
dy_int8_tensor
.
view
(
-
1
,
dy_int8
.
size
(
-
1
)),
transb
,
x_scales_tensor
,
dy_scales_tensor
,
dw_tensor
.
view
(
-
1
,
dw
.
size
(
-
1
)),
b
,
out_quantizer
,
TE_DType
[
out_dtype
],
bias
,
bias_dtype
,
use_gelu
,
aux_tensor
,
use_grad
,
workspace
,
workspace
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
)[
0
]
# print("dw_tensor: ", dw_tensor)
torch
.
testing
.
assert_close
(
dw_ref_tensor
,
dw_tensor
,
atol
=
1e-5
,
rtol
=
1e-5
)
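The helpers this test calls (to_float8_CS, per_token_quant_fp8_to_int8, tensorwise_dequantize, channelwise_dequantize_*) are imported from elsewhere and are not shown in this diff. Under the assumption that the two scale arguments are the per-tensor inverse scales of the INT8 operands, a plausible reference for what the tensorwise variant computes is simply:

def tensorwise_dequantize_ref(a_scale_inv, b_scale_inv, y_int32, out):
    # Hypothetical reference semantics (not the commit's kernel):
    # out <- y_int32 * (a_scale_inv * b_scale_inv), cast to out.dtype.
    out.copy_((y_int32.to(torch.float32) * (a_scale_inv * b_scale_inv)).to(out.dtype))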
transformer_engine/common/gemm/cublaslt_gemm.cu (view file @ 7a923605)
@@ -684,8 +684,9 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
  const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);
  const char *NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
  if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1' && use_int8 && use_split_accumulator)
    nvte_use_hipblaslt = 1;
  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr) || (use_fp8) ||
      (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') ||
      (nvte_use_hipblaslt) || (nvte_use_rocblas)) {
    NVTE_CHECK(!use_int8, "Int8 gemm only supports pure int8 gemm without any epilogue.");
    cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
                transa, transb, grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
                use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream,
                nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
@@ -1100,4 +1101,78 @@ void nvte_cublas_batchgemm_v2(const NVTETensor A, const NVTETensor B, NVTETensor
                          batch_count, stream);
}

// add for batchgemm
void nvte_cublas_batchgemm_v3(const NVTETensor A, const NVTETensor B, const NVTETensor A_scales,
                              const NVTETensor B_scales, NVTETensor D, const NVTETensor bias,
                              NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                              NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                              int math_sm_count, int batch_count, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_batchgemm_v3);
  using namespace transformer_engine;
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  const Tensor *inputA_scales = convertNVTETensorCheck(A_scales);
  const Tensor *inputB_scales = convertNVTETensorCheck(B_scales);
  Tensor *outputD = convertNVTETensor(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensor(workspace);

  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr)) {
    NVTE_ERROR("MOE batchgemm does not support bias or gelu.");
  }

  int m, n, k;
  if (!transa && transb) {  // for NT
    m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
    k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
    n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
  } else if (transa && !transb) {  // for TN
    m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
    k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
    n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
  } else if (!transa && !transb) {  // for NN
    m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
    k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
    n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
  }

  int lda, ldb, ldd;
  if (transa && !transb) {  // TN
    lda = k;
    ldb = k;
    ldd = m;
  } else if (!transa && !transb) {  // NN
    lda = m;
    ldb = k;
    ldd = m;
  } else if (!transa && transb) {  // NT
    lda = m;
    ldb = n;
    ldd = m;
  } else {  // TT
    NVTE_ERROR("TT layout not allowed.");
  }

  hipblasLtHandle_t handle = nullptr;
  // Init hipblaslt handles (once, globally)
  static std::once_flag init_flag;
  static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
  std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
  handle = hipblaslt_handles[0];

  hipblaslt_batchgemm_tensorwise_int8(inputA, inputB, inputA_scales, inputB_scales, outputD,
                                      biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
                                      (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                                      (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N, grad,
                                      wspace->data.dptr, wspace->data.shape[0], accumulate,
                                      use_split_accumulator, math_sm_count, 0, 0, false, nullptr,
                                      batch_count, stream, handle);
}
#endif
\ No newline at end of file
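nvte_cublas_batchgemm_v3 receives the whole batch stacked into 2-D tensors and recovers the per-GEMM sizes by dividing the stacked dimension by batch_count, following the column-major transa/transb convention above. A shape-only mirror of that bookkeeping, for reference (hypothetical helper, not part of the commit):

def batched_gemm_dims(a_shape, b_shape, transa, transb, batch_count):
    # Mirrors the m/k/n computation in nvte_cublas_batchgemm_v3 for stacked 2-D operands.
    if transa and not transb:        # TN
        m, k, n = a_shape[0] // batch_count, a_shape[1], b_shape[0] // batch_count
    elif not transa and transb:      # NT
        m, k, n = a_shape[1], a_shape[0] // batch_count, b_shape[1]
    elif not transa and not transb:  # NN
        m, k, n = a_shape[1], a_shape[0] // batch_count, b_shape[0] // batch_count
    else:
        raise ValueError("TT layout not allowed.")
    return m, n, k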
transformer_engine/common/gemm/rocm_gemm.cu (view file @ 7a923605)
@@ -971,14 +971,17 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
                     hipblasLtHandle_t handle) {
  void *A = inputA->data.dptr;
  void *A_scale_inverse = inputA->scale_inv.dptr;
  float *A_scale_inverse_float = (float *)(inputA->scale_inv.dptr);
  void *B = inputB->data.dptr;
  void *B_scale_inverse = inputB->scale_inv.dptr;
  float *B_scale_inverse_float = (float *)(inputB->scale_inv.dptr);
  void *D = outputD->data.dptr;
  void *bias_ptr = inputBias->data.dptr;
  const bool bias = bias_ptr != nullptr;
  void *pre_gelu_out = outputPreGelu->data.dptr;
  const bool gelu = pre_gelu_out != nullptr;
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);
  const bool use_int8 = is_int8_dtype(inputA->data.dtype) || is_int8_dtype(inputB->data.dtype);
  const hipDataType A_type = get_hipblaslt_dtype(inputA->data.dtype);
  const hipDataType B_type = get_hipblaslt_dtype(inputB->data.dtype);
  const hipDataType D_type = get_hipblaslt_dtype(outputD->data.dtype);
@@ -988,11 +991,19 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
"FP8 input to GEMM requires inverse of scale!"
);
NVTE_CHECK
(
!
is_fp8_dtype
(
inputB
->
data
.
dtype
)
||
B_scale_inverse
!=
nullptr
,
"FP8 input to GEMM requires inverse of scale!"
);
NVTE_CHECK
(
!
is_int8_dtype
(
inputA
->
data
.
dtype
)
||
A_scale_inverse
!=
nullptr
,
"INT8 input to GEMM requires inverse of scale!"
);
NVTE_CHECK
(
!
is_int8_dtype
(
inputB
->
data
.
dtype
)
||
B_scale_inverse
!=
nullptr
,
"INT8 input to GEMM requires inverse of scale!"
);
bool
tensorwise_int8
=
0
;;
const
char
*
NVTE_INT8_SIM_FP8_TENSORWISE
=
std
::
getenv
(
"NVTE_INT8_SIM_FP8_TENSORWISE"
);
if
(
NVTE_INT8_SIM_FP8_TENSORWISE
!=
nullptr
&&
NVTE_INT8_SIM_FP8_TENSORWISE
[
0
]
==
'1'
&&
use_int8
)
tensorwise_int8
=
1
;
// check consistency of arguments:
// if fp8 is desired, context cannot be null
// fp8 + gelu fusion + fp8 aux is unavailable right now.
if
(
use_fp8
)
{
if
(
use_fp8
||
use_int8
)
{
NVTE_CHECK
(
!
gelu
,
"fp8 gemm + gelu fusion is unavailable right now!"
);
}
float
one
=
1.0
;
...
...
@@ -1014,6 +1025,17 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
  hipblasLtMatmulPreference_t preference = nullptr;
  hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT;
  hipblasLtMatmulFlags_t matmul_flag = HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_BF16;
  if (tensorwise_int8) {
    if (D_type == HIP_R_16BF) {
      matmul_flag = HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_BF16;
    } else if (D_type == HIP_R_32F) {
      matmul_flag = HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_FP32;
    } else {
      NVTE_CHECK(false, "tensorwise_int8 only supports D_type bf16 or fp32!");
    }
  }
  int64_t ld_gelumat = (int64_t)ldd;

  // default to tf32 except for e5m2 inputs where the config is not supported
@@ -1026,7 +1048,11 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
                                                   transb == HIPBLAS_OP_N ? n : k, ldb));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
  if (tensorwise_int8) {
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F, matmul_flag));
  } else {
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
  }
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSB,
@@ -1055,6 +1081,19 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
                                                         operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
    }
  }
  if (tensorwise_int8) {
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                         (void *)&A_scale_inverse_float, sizeof(void *)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                         (void *)&B_scale_inverse_float, sizeof(void *)));
    if (bias) {
      NVTE_CHECK(false, "tensorwise_int8 does not support bias!");
    }
  }
  if (bias && gelu) {
    if (grad) {
@@ -1260,6 +1299,422 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
}

void hipblaslt_batchgemm_tensorwise_int8(
    const Tensor *inputA, const Tensor *inputB, const Tensor *inputA_scales,
    const Tensor *inputB_scales, Tensor *outputD, const Tensor *inputBias, Tensor *outputPreGelu,
    int m, int n, int k, int lda, int ldb, int ldd, hipblasOperation_t transa,
    hipblasOperation_t transb, bool grad, void *workspace, size_t workspaceSize, bool accumulate,
    bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer,
    const Tensor *inputCounter, size_t batch_count, hipStream_t stream, hipblasLtHandle_t handle) {
  void *A = inputA->data.dptr;
  void *A_scale_inverse = inputA_scales->data.dptr;
  float *A_scale_inverse_float = (float *)(inputA_scales->data.dptr);
  void *B = inputB->data.dptr;
  void *B_scale_inverse = inputB_scales->data.dptr;
  float *B_scale_inverse_float = (float *)(inputB_scales->data.dptr);
  void *D = outputD->data.dptr;
  void *bias_ptr = inputBias->data.dptr;
  const bool bias = bias_ptr != nullptr;
  void *pre_gelu_out = outputPreGelu->data.dptr;
  const bool gelu = pre_gelu_out != nullptr;
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);
  const bool use_int8 = is_int8_dtype(inputA->data.dtype) || is_int8_dtype(inputB->data.dtype);
  const hipDataType A_type = get_hipblaslt_dtype(inputA->data.dtype);
  const hipDataType B_type = get_hipblaslt_dtype(inputB->data.dtype);
  const hipDataType D_type = get_hipblaslt_dtype(outputD->data.dtype);
  const hipDataType bias_type = get_hipblaslt_dtype(inputBias->data.dtype);

  NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
             "FP8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_fp8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
             "FP8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_int8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
             "INT8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_int8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
             "INT8 input to GEMM requires inverse of scale!");

  bool tensorwise_int8 = 0;
  const char *NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
  if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1' && use_int8)
    tensorwise_int8 = 1;

  // check consistency of arguments:
  // if fp8 is desired, context cannot be null
  // fp8 + gelu fusion + fp8 aux is unavailable right now.
  if (use_fp8 || use_int8) {
    NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!");
  }
  float one = 1.0;
  float zero = 0.0;
  float beta = (accumulate) ? one : zero;

  int device_id;
  NVTE_CHECK_CUDA(hipGetDevice(&device_id));
  if (handle == nullptr) {
    handle = cached_handles.get(device_id);
    if (handle == nullptr) {
      handle = cached_handles.obtain(device_id);
    }
  }

  hipblasLtMatmulDesc_t operationDesc = nullptr;
  hipblasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
  hipblasLtMatmulPreference_t preference = nullptr;
  hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT;
  hipblasLtMatmulFlags_t matmul_flag = HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_BF16;
  if (tensorwise_int8) {
    if (D_type == HIP_R_16BF) {
      matmul_flag = HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_BF16;
    } else if (D_type == HIP_R_32F) {
      matmul_flag = HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_FP32;
    } else {
      NVTE_CHECK(false, "tensorwise_int8 only supports D_type bf16 or fp32!");
    }
  }
  int64_t ld_gelumat = (int64_t)ldd;

  // default to tf32 except for e5m2 inputs where the config is not supported
  hipblasComputeType_t gemm_compute_type = HIPBLAS_COMPUTE_32F;

  // Create matrix descriptors. Not setting any extra attributes.
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Adesc, A_type, transa == HIPBLAS_OP_N ? m : k,
                                                   transa == HIPBLAS_OP_N ? k : m, lda));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Bdesc, B_type, transb == HIPBLAS_OP_N ? k : n,
                                                   transb == HIPBLAS_OP_N ? n : k, ldb));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
  if (tensorwise_int8) {
    size_t strideA = m * k;
    size_t strideB = k * n;
    size_t strideD = m * n;
    hipblasLtMatrixLayoutSetAttribute(Adesc, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t));
    hipblasLtMatrixLayoutSetAttribute(Adesc, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideA, sizeof(int64_t));
    hipblasLtMatrixLayoutSetAttribute(Bdesc, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t));
    hipblasLtMatrixLayoutSetAttribute(Bdesc, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideB, sizeof(int64_t));
    hipblasLtMatrixLayoutSetAttribute(Ddesc, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t));
    hipblasLtMatrixLayoutSetAttribute(Ddesc, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideD, sizeof(int64_t));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F, matmul_flag));
  } else {
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
  }
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)));

  // set fp8 attributes -- input and output types should already be set to fp8 as appropriate
  // Note: gelu fusion isn't available right now, and we don't need
  // amax(D) either (next op is high precision).
  if (use_fp8) {
    // Split accumulator.
    const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
    /*
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                HIPBLASLT_MATMUL_DESC_FAST_ACCUM, //TODO: We don't have fast accum mode yet
                                &fastAccuMode,
                                sizeof(fastAccuMode)));
    */
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, &A_scale_inverse, sizeof(A_scale_inverse)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale_inverse, sizeof(B_scale_inverse)));
    if (bias) {
      NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
    }
  }
  if (tensorwise_int8) {
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, (void *)&A_scale_inverse_float, sizeof(void *)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, (void *)&B_scale_inverse_float, sizeof(void *)));
    if (bias) {
      NVTE_CHECK(false, "tensorwise_int8 does not support bias!");
    }
  }
  if (bias && gelu) {
    if (grad) {
      epilogue = HIPBLASLT_EPILOGUE_DGELU_BGRAD;
    } else {
      epilogue = HIPBLASLT_EPILOGUE_GELU_AUX_BIAS;
    }
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
  } else if (bias) {
    if (grad) {
      // grad output is always input B
      epilogue = HIPBLASLT_EPILOGUE_BGRADB;
    } else {
      epilogue = HIPBLASLT_EPILOGUE_BIAS;
    }
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
  } else if (gelu) {
    if (grad) {
      epilogue = HIPBLASLT_EPILOGUE_DGELU;
    } else {
      epilogue = HIPBLASLT_EPILOGUE_GELU_AUX;
    }
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
  }
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));

  GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type,
                              use_fp8 ? bias_type : (hipDataType)-1, m, n, k, lda, ldb, ldd,
                              transa, transb, epilogue);
  GemmAlgoCache::Algo cached_algo;
  if (algoCache.find(gemm_cfg, workspaceSize, cached_algo) == 0 || !cached_algo.algo.has_value()) {
    int firstAlgo = getIntEnv("TE_HIPBLASLT_ALGO_SELECTION", 0, 0);
    int tuneLoopCount = getIntEnv("TE_HIPBLASLT_TUNING_RUN_COUNT", 0, 0);
    int algoTuneCount = 1;
    std::vector<hipblasLtMatmulHeuristicResult_t> algoArr;
    bool logTuning = getIntEnv("TE_HIPBLASLT_LOG_TUNING", 0, 0) != 0;
    if (tuneLoopCount) {
      /* HIPBLASLT may return hundreds of algos for some configs
       * Limit amount by default. User may override with env
       */
      static const int defaultAlgoCount = 16;
      algoTuneCount = getIntEnv("TE_HIPBLASLT_TUNING_ALGO_COUNT", defaultAlgoCount, 1);
    }
    algoTuneCount += firstAlgo;
    int algoTotalCount = cached_algo.hasId() ? std::max(algoTuneCount, (cached_algo.index + 1)) : algoTuneCount;
    algoArr.resize(algoTotalCount);
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceCreate(&preference));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceSetAttribute(preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Ddesc, Ddesc, preference, algoTotalCount, algoArr.data(), &algoTotalCount));
    algoArr.resize(algoTotalCount);
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceDestroy(preference));

    // If cached algo exists in persistent storage we just need to find matching hipblasLtMatmulAlgo_t
    if (cached_algo.hasId()) {
      int idx = (cached_algo.index < algoTotalCount) ? cached_algo.index : 0;
      for (int i = 0; i < algoTotalCount; i++) {
        const auto &algo = algoArr[idx];
        if (algo.state == HIPBLAS_STATUS_SUCCESS) {
          if (cached_algo.algoId == cached_algo.getAlgoId(algo.algo)) {
            cached_algo.algo = algo.algo;
            if (algo.workspaceSize != cached_algo.ws_size_min || idx != cached_algo.index) {
              cached_algo.ws_size_min = algo.workspaceSize;
              cached_algo.index = idx;
              algoCache.store(gemm_cfg, cached_algo);
            }
            break;
          }
        }
        idx = (idx + 1) % algoTotalCount;
      }
      if (logTuning && !cached_algo.algo.has_value()) {
        std::cout << "[WARNING] Cannot find cached algoId " << cached_algo.algoId
                  << " in hipBLASLt results" << std::endl;
      }
    }

    // No suitable entry in autotune cache or could not find matched algo in hipBLASLt results
    if (!cached_algo.algo.has_value()) {
      int bestAlgo = -1;
      algoTuneCount = std::min(algoTuneCount, algoTotalCount);
      if (tuneLoopCount > 0) {
        if (logTuning)
          std::cout << "[INFO] Perform hipBLASLt algo selection on GPU" << device_id << " in range ["
                    << firstAlgo << "-" << (algoTuneCount - 1) << "] with " << tuneLoopCount
                    << " loops " << std::endl;
        NVTE_CHECK_CUDA(hipStreamSynchronize(stream));
        hipStream_t profilingStream;
        NVTE_CHECK_CUDA(hipStreamCreateWithFlags(&profilingStream, hipStreamNonBlocking));
        using tuning_clock = std::chrono::steady_clock;
        tuning_clock::now();
        // the first call takes little longer so do it outside the loop
        tuning_clock::duration bestTime = tuning_clock::duration::max();
        for (int algo = firstAlgo; algo < algoTuneCount; algo++) {
          if (algoArr[algo].state != HIPBLAS_STATUS_SUCCESS) {
            continue;
          }
          // Warm-up call
          NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc,
                                               static_cast<const void *>(&one),  /* alpha */
                                               A,                                /* A */
                                               Adesc, B,                         /* B */
                                               Bdesc,
                                               static_cast<const void *>(&beta), /* beta */
                                               D,                                /* C */
                                               Ddesc, D,                         /* D */
                                               Ddesc, &algoArr[algo].algo,       /* algo */
                                               workspace,                        /* workspace */
                                               workspaceSize, profilingStream)); /* stream */
          NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream));
          // Profiling loop
          tuning_clock::time_point startTime = tuning_clock::now();
          for (int loop = 0; loop < tuneLoopCount; loop++) {
            NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc,
                                                 static_cast<const void *>(&one),  /* alpha */
                                                 A,                                /* A */
                                                 Adesc, B,                         /* B */
                                                 Bdesc,
                                                 static_cast<const void *>(&beta), /* beta */
                                                 D,                                /* C */
                                                 Ddesc, D,                         /* D */
                                                 Ddesc, &algoArr[algo].algo,       /* algo */
                                                 workspace,                        /* workspace */
                                                 workspaceSize, profilingStream)); /* stream */
          }
          NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream));
          tuning_clock::duration algoTime = tuning_clock::now() - startTime;
          if (algoTime < bestTime) {
            bestAlgo = algo;
            bestTime = algoTime;
          }
        }
        NVTE_CHECK_CUDA(hipStreamDestroy(profilingStream));
        if (bestAlgo >= 0) {
          if (logTuning)
            std::cout << "[INFO] Select hipBLASLt algo " << bestAlgo << " with time "
                      << std::chrono::duration_cast<std::chrono::nanoseconds>(bestTime).count() / tuneLoopCount
                      << " ns" << std::endl;
        }
      } else if (firstAlgo < algoTuneCount) {
        bestAlgo = firstAlgo;
      }
      if (bestAlgo < 0) {
        NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc));
        NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc));
        NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc));
        NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
        throw std::runtime_error("Unable to find any suitable algorithms");
      }
      cached_algo.algo = algoArr[bestAlgo].algo;
      cached_algo.index = bestAlgo;
      cached_algo.algoId = cached_algo.getAlgoId(algoArr[bestAlgo].algo);
      cached_algo.ws_size_min = algoArr[bestAlgo].workspaceSize;
      cached_algo.ws_size_max = workspaceSize;
      if (logTuning)
        std::cout << "[INFO] Use hipBLASLt algo [" << bestAlgo << "] " << cached_algo.algoId << std::endl;
      algoCache.store(gemm_cfg, cached_algo);
    }
  }

  // D = alpha * (A * B) + beta * C
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc,
                                       static_cast<const void *>(&one),  /* alpha */
                                       A,                                /* A */
                                       Adesc, B,                         /* B */
                                       Bdesc,
                                       static_cast<const void *>(&beta), /* beta */
                                       D,                                /* C */
                                       Ddesc, D,                         /* D */
                                       Ddesc, &cached_algo.algo.value(), /* algo */
                                       workspace,                        /* workspace */
                                       workspaceSize, stream));          /* stream */

  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
}

class userArgsManager {
 public:
  userArgsManager() {}
@@ -1357,7 +1812,7 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
  if (compute_stream_offset != -1) {
    // Init hipblaslt handles (once, globally)
    static std::once_flag init_flag;
    static hipblasLtHandle_t hipblaslt_handles[1];
    static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
    std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
    handle = hipblaslt_handles[compute_stream_offset];
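The strided-batch setup in hipblaslt_batchgemm_tensorwise_int8 (one scale pointer per operand, strideA = m*k, strideB = k*n, strideD = m*n) assumes every GEMM in the batch shares the same shapes and the same tensorwise scales, so a loop of independent GEMMs and one strided-batch call are numerically identical. A tiny CPU check of that equivalence (illustrative only; the scales here are made up):

import torch

b, m, n, k = 4, 16, 24, 32
A = torch.randint(-127, 128, (b, n, k), dtype=torch.int8)
B = torch.randint(-127, 128, (b, m, k), dtype=torch.int8)
a_scale_inv, b_scale_inv = 0.02, 0.03  # assumed per-tensor inverse scales
looped = torch.stack([A[i].float() @ B[i].float().t() for i in range(b)]) * (a_scale_inv * b_scale_inv)
batched = torch.matmul(A.float(), B.float().transpose(1, 2)) * (a_scale_inv * b_scale_inv)
assert torch.equal(looped, batched)  # exact: every intermediate value is a small integer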
transformer_engine/common/include/transformer_engine/gemm.h (view file @ 7a923605)
@@ -134,6 +134,11 @@ void nvte_cublas_batchgemm_v2(const NVTETensor A, const NVTETensor B, NVTETensor
                              NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                              NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                              int math_sm_count, int batch_count, cudaStream_t stream);

void nvte_cublas_batchgemm_v3(const NVTETensor A, const NVTETensor B, const NVTETensor A_scales,
                              const NVTETensor B_scales, NVTETensor D, const NVTETensor bias,
                              NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                              NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                              int math_sm_count, int batch_count, cudaStream_t stream);
#endif

#ifdef __cplusplus
transformer_engine/common/transformer_engine.cpp (view file @ 7a923605)
@@ -131,7 +131,7 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
void CheckInputTensor(const Tensor &t, const std::string &name) {
  const DType type = t.dtype();
  if (is_fp8_dtype(type)) {
  if (is_fp8_dtype(type) || is_int8_dtype(type)) {
    // FP8 input needs to have scale_inv
    if (t.has_data()) {
      NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor input ", name,
transformer_engine/pytorch/cpp_extensions/gemm.py (view file @ 7a923605)
@@ -38,6 +38,7 @@ from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_q
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.pytorch.fp8 import int8_simulation_fp8, int8_simulation_fp8_tensorwise

tensorwise_int8_check = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE_CHECK", "0")))

__all__ = [
    "general_gemm",
    "general_grouped_gemm",
@@ -181,6 +182,47 @@ def general_gemm(
    ):
        raise RuntimeError("GEMM with Float8BlockwiseQTensor requires GEMM_READY format")

    if int8_simulation_fp8 and (isinstance(A, Float8TensorBase) or isinstance(B, Float8TensorBase)) and int8_simulation_fp8_tensorwise:
        assert not gelu, "GELU not supported with int8 simulation"
        assert gelu_in is None, "GELU input not supported with int8 simulation"
        assert bias is None, "Bias not supported with int8 simulation"
        assert ub is None, "User buffer not supported with int8 simulation"
        assert ub_type is None, "User buffer type not supported with int8 simulation"
        assert extra_output is None, "Extra output not supported with int8 simulation"
        assert not bulk_overlap, "Bulk overlap not supported with int8 simulation"
        assert out_dtype is torch.bfloat16 or out_dtype is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
        args = (
            A,
            transa,  # transa
            B,
            transb,  # transb
            out,
            quantization_params,
            TE_DType[out_dtype] if out_dtype is not None else None,
            bias,
            bias_dtype,
            gelu,
            gelu_in,
            grad,  # grad
            workspace,
            workspace.shape[0],
            accumulate,
            True,
        )
        kwargs = {
            "comm_overlap": ub,
            "comm_type": ub_type,
            "extra_output": extra_output,
            "bulk_overlap": bulk_overlap,
        }
        out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
        if debug_quantizer is not None:
            out = debug_quantizer.process_gemm_output(out)
        return out, bias_grad, gelu_input, extra_output

    if int8_simulation_fp8 and (isinstance(A, Float8TensorBase) or isinstance(B, Float8TensorBase)):
        assert not gelu, "GELU not supported with int8 simulation"
@@ -195,10 +237,6 @@ def general_gemm(
        if layout == "TN":
            assert out_dtype is torch.bfloat16
            out_shape = B._data.shape[:-1] + (A._data.shape[0],)
            if int8_simulation_fp8_tensorwise:
                x_int8, x_scales = B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv
                w_int8, w_scales = A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv
            else:
                x_int8, x_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
                w_int8, w_scales = per_token_quant_fp8_to_int8(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
@@ -220,20 +258,12 @@ def general_gemm(
                False,
                use_split_accumulator,
            )[0]
            if int8_simulation_fp8_tensorwise:
                y = torch.empty_like(y_int32, device=y_int32.device, dtype=torch.bfloat16)
                tensorwise_dequantize(x_scales, w_scales, y_int32, y)
            else:
                y = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
            return y.view(out_shape), None, None, None
        elif layout == "NN":
            assert out_dtype is torch.bfloat16
            dx_shape = B._data.shape[:-1] + (A._data.shape[-1],)
            if int8_simulation_fp8_tensorwise:
                dy_int8, dy_scales = B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv
                w_int8, w_scales = A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv
            else:
                dy_int8, dy_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
                w_int8, w_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
@@ -255,19 +285,11 @@ def general_gemm(
                False,
                use_split_accumulator,
            )[0]
            if int8_simulation_fp8_tensorwise:
                dx = torch.empty_like(dx_int32, device=dx_int32.device, dtype=torch.bfloat16)
                tensorwise_dequantize(dy_scales, w_scales, dx_int32, dx)
            else:
                dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
            return dx.view(dx_shape), None, None, None
        elif layout == "NT":
            assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
            if int8_simulation_fp8_tensorwise:
                dy_int8, dy_scales = B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv
                x_int8, x_scales = A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv
            else:
                dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
                x_int8, x_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
@@ -291,27 +313,13 @@ def general_gemm(
            )[0]
            if out_dtype is torch.bfloat16:
                if accumulate:
                    if int8_simulation_fp8_tensorwise:
                        tensorwise_dequantize_add(dy_scales, x_scales, dw_int32, out)
                    else:
                        channelwise_dequantize_transA_add(dy_scales, x_scales, dw_int32, out)
                else:
                    if int8_simulation_fp8_tensorwise:
                        out = torch.empty_like(dw_int32, device=dw_int32.device, dtype=torch.bfloat16)
                        tensorwise_dequantize(dy_scales, x_scales, dw_int32, out)
                    else:
                        out = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
            else:
                if accumulate:
                    if int8_simulation_fp8_tensorwise:
                        tensorwise_dequantize_float_add(dy_scales, x_scales, dw_int32, out)
                    else:
                        channelwise_dequantize_transA_float_add(dy_scales, x_scales, dw_int32, out)
                else:
                    if int8_simulation_fp8_tensorwise:
                        out = torch.empty_like(dw_int32, device=dw_int32.device, dtype=torch.float32)
                        tensorwise_dequantize_float(dy_scales, x_scales, dw_int32, out)
                    else:
                        out = channelwise_dequantize_transA_float(dy_scales, x_scales, dw_int32)
            return out, None, None, None
@@ -465,6 +473,120 @@ def general_grouped_gemm(
        else:
            raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")

    if int8_simulation_fp8 and (isinstance(A, Float8TensorBase) or isinstance(B, Float8TensorBase)) and int8_simulation_fp8_tensorwise:
        assert len(set(m_splits)) == 1, "Int8 simulation groupgemm only supports token pad the same as batchgemm for now."
        assert not gelu, "GELU not supported with int8 simulation groupgemm."
        assert not use_bias, "Bias not supported with int8 simulation groupgemm."
        assert out_dtype is torch.bfloat16 or out_dtype is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
        if layout == "TN":
            assert out_dtype is torch.bfloat16
            qx_data_list, scales_x_list = [b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]) for b in B], [b._scale_inv for b in B]
            w_data_list, scales_w_list = [a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]) for a in A], [a._scale_inv for a in A]
            num_gemms = len(A)
            qx_data = torch.stack(qx_data_list).contiguous()
            w_data = torch.stack(w_data_list).contiguous()
            scales_x = torch.stack(scales_x_list).contiguous()
            scales_w = torch.stack(scales_w_list).contiguous()
            out[0] = tex.tensorwise_int8_batchgemm(
                w_data.view(-1, w_data.size(-1)), transa,
                qx_data.view(-1, qx_data.size(-1)), transb,
                scales_w, scales_x, out[0], num_gemms, None, TE_DType[out_dtype],
                bias, bias_dtype, gelu, gelu_input[0],
                grad,  # grad
                workspaces[0], workspaces[0].shape[0],
                accumulate, use_split_accumulator,
            )[0]
            return out, bias, gelu_input
        if layout == "NN":
            assert out_dtype is torch.bfloat16
            qdout_data_list, scales_dout_list = [b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]) for b in B], [b._scale_inv for b in B]
            w_data_list, scales_w_list = [a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]) for a in A], [a._scale_inv for a in A]
            num_gemms = len(A)
            qdout_data = torch.stack(qdout_data_list).contiguous()
            w_data = torch.stack(w_data_list).contiguous()
            scales_dout = torch.stack(scales_dout_list).contiguous()
            scales_w = torch.stack(scales_w_list).contiguous()
            out[0] = tex.tensorwise_int8_batchgemm(
                w_data.view(-1, w_data.size(-1)), transa,
                qdout_data.view(-1, qdout_data.size(-1)), transb,
                scales_w, scales_dout, out[0], num_gemms, None, TE_DType[out_dtype],
                bias, bias_dtype, gelu, gelu_input[0],
                grad,  # grad
                workspaces[0], workspaces[0].shape[0],
                accumulate, use_split_accumulator,
            )[0]
            return out, bias, gelu_input
        elif layout == "NT":
            assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
            qdout_data_list, scales_dout_list = [b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]) for b in B], [b._scale_inv for b in B]
            qx_data_list, scales_x_list = [a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]) for a in A], [a._scale_inv for a in A]
            num_gemms = len(A)
            qdout_data = torch.stack(qdout_data_list).contiguous()
            qx_data = torch.stack(qx_data_list).contiguous()
            scales_dout = torch.stack(scales_dout_list).contiguous()
            scales_x = torch.stack(scales_x_list).contiguous()
            dw = torch.stack(out).contiguous()
            dw = tex.tensorwise_int8_batchgemm(
                qx_data.view(-1, qx_data.size(-1)), transa,
                qdout_data.view(-1, qdout_data.size(-1)), transb,
                scales_x, scales_dout, dw.view(-1, dw.size(-1)), num_gemms, None, TE_DType[out_dtype],
                bias, bias_dtype, gelu, gelu_input[0],
                grad,  # grad
                workspaces[0], workspaces[0].shape[0],
                accumulate, use_split_accumulator,
            )
            for i in range(num_gemms):
                out[i].copy_(dw[i])
            return out, bias, gelu_input

    if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)):
        assert len(set(m_splits)) == 1, "Int8 simulation groupgemm only supports token pad the same as batchgemm for now."
        assert not gelu, "GELU not supported with int8 simulation groupgemm."
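The grouped path above insists on len(set(m_splits)) == 1, i.e. every group is padded to the same token count; that is what allows the grouped GEMM to be flattened into a single tensorwise_int8_batchgemm call. A toy illustration of that equivalence (float here, shapes only; nothing below is from this commit):

import torch

num_gemms, m_per_group, k, n = 4, 8, 16, 12
xs = [torch.randn(m_per_group, k) for _ in range(num_gemms)]   # equal m_splits
ws = [torch.randn(n, k) for _ in range(num_gemms)]
grouped = [x @ w.t() for x, w in zip(xs, ws)]                   # one GEMM per group
batched = torch.bmm(torch.stack(xs), torch.stack(ws).transpose(1, 2))  # single batched GEMM
assert all(torch.allclose(g, batched[i], rtol=1e-5, atol=1e-5) for i, g in enumerate(grouped))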
transformer_engine/pytorch/csrc/extensions.h (view file @ 7a923605)
@@ -88,14 +88,6 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
                             std::optional<CommOverlapType> comm_type = std::nullopt,
                             MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);

std::vector<py::object> generic_batchgemm(
    py::handle A, bool transa, py::handle B, bool transb, py::object D, int batch_count,
    py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias, DType bias_type,
    bool gelu, MaybeTensor gelu_in, bool grad, at::Tensor workspace, size_t workspaceSize,
    bool accumulate, bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
    std::optional<CommOverlapType> comm_type = std::nullopt,
    MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);

void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
                    std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
                    at::Tensor B_scale_inverse, DType B_type, std::vector<int64_t> B_scaling_mode,
@@ -113,6 +105,22 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
    size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count);

#ifdef __HIP_PLATFORM_AMD__
std::vector<py::object> generic_batchgemm(
    py::handle A, bool transa, py::handle B, bool transb, py::object D, int batch_count,
    py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias, DType bias_type,
    bool gelu, MaybeTensor gelu_in, bool grad, at::Tensor workspace, size_t workspaceSize,
    bool accumulate, bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
    std::optional<CommOverlapType> comm_type = std::nullopt,
    MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);
std::vector<py::object> tensorwise_int8_batchgemm(
    py::handle A, bool transa, py::handle B, bool transb, py::handle A_scales,
    py::handle B_scales, py::object D, int batch_count, py::handle quantizer,
    std::optional<DType> out_dtype, MaybeTensor bias, DType bias_type, bool gelu,
    MaybeTensor gelu_in, bool grad, at::Tensor workspace, size_t workspaceSize,
    bool accumulate, bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
    std::optional<CommOverlapType> comm_type = std::nullopt,
    MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);

void te_batchgemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
                  transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
                  at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
transformer_engine/pytorch/csrc/extensions/gemm.cpp (view file @ 7a923605)
@@ -271,138 +271,6 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
  return out;
}

std::vector<py::object> generic_batchgemm(py::handle A, bool transa, py::handle B, bool transb,
                                          py::object D, int batch_count, py::handle quantizer,
                                          std::optional<DType> out_dtype, MaybeTensor bias,
                                          DType bias_type, bool gelu, MaybeTensor gelu_in,
                                          bool grad, at::Tensor workspace, size_t workspaceSize,
                                          bool accumulate, bool use_split_accumulator,
                                          CommOverlapCore *comm_overlap,
                                          std::optional<CommOverlapType> comm_type,
                                          MaybeTensor extra_output, bool bulk_overlap) {
  // Input tensors
  NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
  NVTE_CHECK(!B.is_none(), "Tensor B has not been provided");
  auto none = py::none();
  TensorWrapper A_tensor = makeTransformerEngineTensor(A, none);
  TensorWrapper B_tensor = makeTransformerEngineTensor(B, none);

  const bool low_precision = detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype());

  // Check tensor dimensions
  const auto &A_shape = A_tensor.shape();
  const auto &B_shape = B_tensor.shape();
  const auto &D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb);
  NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension");
  NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension");

  // Output tensor
  TensorWrapper D_tensor;
  if (D.is_none()) {
    NVTE_ERROR("generic batchgemm D must be not None.");
  } else {
    D_tensor = makeTransformerEngineTensor(D, quantizer);
    if (out_dtype) {
      NVTE_CHECK(*out_dtype == D_tensor.dtype(), "GEMM output has invalid dtype (expected ",
                 static_cast<int>(*out_dtype), ", found ", static_cast<int>(D_tensor.dtype()), ")");
    }
  }

  // Bias tensor
  TensorWrapper bias_tensor;
  MaybeTensor bias_grad = std::nullopt;
  if (bias.has_value()) {
    if (grad) {
      auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA);
      bias_grad = at::empty({static_cast<int64_t>(B_shape.data[B_shape.ndim - 1])}, opts);
      bias_tensor = makeTransformerEngineTensor(*bias_grad);
    } else {
      if (!bias->is_contiguous()) {
        bias = bias->contiguous();
      }
      bias_tensor = makeTransformerEngineTensor(*bias);
    }
  }

  // Activation input tensor
  MaybeTensor pre_gelu_out = std::nullopt;
  DType gelu_type = low_precision ? bias_type : D_tensor.dtype();
  if (gelu) {
    if (!grad) {
      auto dtype = GetATenDType(gelu_type);
      auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
      std::vector<int64_t> torch_shape;
      for (auto v : D_shape) {
        torch_shape.push_back(v);
      }
      pre_gelu_out = at::empty(torch_shape, opts);
    } else {
      if (gelu_in.has_value()) {
        pre_gelu_out = *gelu_in;
      }
    }
  }
  const auto gelu_shape = gelu ? D_shape : std::vector<size_t>{0};
  auto te_pre_gelu_out = makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type);

  // Workspace
  auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), std::vector<size_t>{workspaceSize}, DType::kByte);

  // Set an external SM Margin to all the GEMMs.
  // This comes in handy when DP is overlapped with GEMMs
  const int device_id = at::cuda::current_device();
  const int sm_count = transformer_engine::cuda::sm_count(device_id);
  int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

  // Keep the swizzled scaling factor tensors alive during the GEMM.
  std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;

  auto main_stream = at::cuda::getCurrentCUDAStream();
  if (A_tensor.numel() != 0 && B_tensor.numel() != 0) {
    // Optionally swizzle the scaling factors
    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa)));
    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(B_tensor, !transb)));
    if (comm_overlap) {
      NVTE_ERROR("generic batchgemm does not support comm_overlap.");
    } else {
      // Launch GEMM
      NVTE_SCOPED_GIL_RELEASE({
        nvte_cublas_batchgemm_v2(A_tensor.data(), B_tensor.data(), D_tensor.data(),
                                 bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad,
                                 te_workspace.data(), accumulate, use_split_accumulator,
                                 num_math_sms, batch_count, main_stream);
      });
    }
  } else {
    if (D_tensor.numel() != 0 && !accumulate) {
      D_tensor.zero_(main_stream);
    }
    if (bias.has_value()) {
      if (bias->numel() != 0 && grad) {
        bias_grad->zero_();
      }
    }
  }

  // Pack outputs
  std::vector<py::object> out;
  out.emplace_back(std::move(D));
  out.emplace_back(py::cast(bias_grad));
  if (gelu && !grad) {
    out.emplace_back(py::cast(*pre_gelu_out));
  } else {
    out.emplace_back(py::none());
  }
  if (extra_output.has_value()) {
    out.emplace_back(py::cast(extra_output));
  } else {
    out.emplace_back(py::none());
  }
  return out;
}

void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
                    std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
                    at::Tensor B_scale_inverse, DType B_type, std::vector<int64_t> B_scaling_mode,
@@ -586,6 +454,274 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
}

#ifdef USE_ROCM
std::vector<py::object> generic_batchgemm(py::handle A, bool transa, py::handle B, bool transb,
                                          py::object D, int batch_count, py::handle quantizer,
                                          std::optional<DType> out_dtype, MaybeTensor bias,
                                          DType bias_type, bool gelu, MaybeTensor gelu_in,
                                          bool grad, at::Tensor workspace, size_t workspaceSize,
                                          bool accumulate, bool use_split_accumulator,
                                          CommOverlapCore* comm_overlap,
                                          std::optional<CommOverlapType> comm_type,
                                          MaybeTensor extra_output, bool bulk_overlap) {
  // Input tensors
  NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
  NVTE_CHECK(!B.is_none(), "Tensor B has not been provided");
  auto none = py::none();
  TensorWrapper A_tensor = makeTransformerEngineTensor(A, none);
  TensorWrapper B_tensor = makeTransformerEngineTensor(B, none);
  const bool low_precision =
      detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype());

  // Check tensor dimensions
  const auto& A_shape = A_tensor.shape();
  const auto& B_shape = B_tensor.shape();
  const auto& D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb);
  NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension");
  NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension");

  // Output tensor
  TensorWrapper D_tensor;
  if (D.is_none()) {
    NVTE_ERROR("generic batchgemm D must not be None.");
  } else {
    D_tensor = makeTransformerEngineTensor(D, quantizer);
    if (out_dtype) {
      NVTE_CHECK(*out_dtype == D_tensor.dtype(), "GEMM output has invalid dtype (expected ",
                 static_cast<int>(*out_dtype), ", found ", static_cast<int>(D_tensor.dtype()), ")");
    }
  }

  // Bias tensor
  TensorWrapper bias_tensor;
  MaybeTensor bias_grad = std::nullopt;
  if (bias.has_value()) {
    if (grad) {
      auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA);
      bias_grad = at::empty({static_cast<int64_t>(B_shape.data[B_shape.ndim - 1])}, opts);
      bias_tensor = makeTransformerEngineTensor(*bias_grad);
    } else {
      if (!bias->is_contiguous()) {
        bias = bias->contiguous();
      }
      bias_tensor = makeTransformerEngineTensor(*bias);
    }
  }

  // Activation input tensor
  MaybeTensor pre_gelu_out = std::nullopt;
  DType gelu_type = low_precision ? bias_type : D_tensor.dtype();
  if (gelu) {
    if (!grad) {
      auto dtype = GetATenDType(gelu_type);
      auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
      std::vector<int64_t> torch_shape;
      for (auto v : D_shape) {
        torch_shape.push_back(v);
      }
      pre_gelu_out = at::empty(torch_shape, opts);
    } else {
      if (gelu_in.has_value()) {
        pre_gelu_out = *gelu_in;
      }
    }
  }
  const auto gelu_shape = gelu ? D_shape : std::vector<size_t>{0};
  auto te_pre_gelu_out =
      makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type);

  // Workspace
  auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
                                                  std::vector<size_t>{workspaceSize}, DType::kByte);

  // Set an external SM Margin to all the GEMMs.
  // This comes in handy when DP is overlapped with GEMMs
  const int device_id = at::cuda::current_device();
  const int sm_count = transformer_engine::cuda::sm_count(device_id);
  int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

  // Keep the swizzled scaling factor tensors alive during the GEMM.
  std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;

  auto main_stream = at::cuda::getCurrentCUDAStream();
  if (A_tensor.numel() != 0 && B_tensor.numel() != 0) {
    // Optionally swizzle the scaling factors
    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa)));
    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(B_tensor, !transb)));

    if (comm_overlap) {
      NVTE_ERROR("generic batchgemm does not support comm_overlap.");
    } else {
      // Launch GEMM
      NVTE_SCOPED_GIL_RELEASE({
        nvte_cublas_batchgemm_v2(A_tensor.data(), B_tensor.data(), D_tensor.data(),
                                 bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad,
                                 te_workspace.data(), accumulate, use_split_accumulator,
                                 num_math_sms, batch_count, main_stream);
      });
    }
  } else {
    if (D_tensor.numel() != 0 && !accumulate) {
      D_tensor.zero_(main_stream);
    }
    if (bias.has_value()) {
      if (bias->numel() != 0 && grad) {
        bias_grad->zero_();
      }
    }
  }

  // Pack outputs
  std::vector<py::object> out;
  out.emplace_back(std::move(D));
  out.emplace_back(py::cast(bias_grad));
  if (gelu && !grad) {
    out.emplace_back(py::cast(*pre_gelu_out));
  } else {
    out.emplace_back(py::none());
  }
  if (extra_output.has_value()) {
    out.emplace_back(py::cast(extra_output));
  } else {
    out.emplace_back(py::none());
  }
  return out;
}
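Editor's note: in the "tensorwise" scheme that the next function consumes, each operand carries a single symmetric int8 scale. Below is a minimal sketch of producing such a payload, assuming symmetric quantization to [-127, 127]; the helper name and the exact rounding/clamping convention are illustrative and not taken from this commit.

import torch

def tensorwise_int8_quantize(x: torch.Tensor):
    """Quantize a whole tensor with one symmetric scale (tensorwise int8).

    Returns the int8 payload and a float scale such that x ~= q.float() * scale.
    """
    amax = x.abs().amax().clamp(min=1e-12)    # per-tensor absolute maximum
    scale = (amax / 127.0).to(torch.float32)  # a single scale for the whole tensor
    q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return q, scale

The product of the two per-tensor scales is what turns the integer accumulator back into a real-valued output; that is presumably the role of the A_scales/B_scales arguments consumed by the function below.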
std::vector<py::object> tensorwise_int8_batchgemm(
    py::handle A, bool transa, py::handle B, bool transb, py::handle A_scales, py::handle B_scales,
    py::object D, int batch_count, py::handle quantizer, std::optional<DType> out_dtype,
    MaybeTensor bias, DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
    at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
    CommOverlapCore* comm_overlap, std::optional<CommOverlapType> comm_type,
    MaybeTensor extra_output, bool bulk_overlap) {
  // Input tensors
  NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
  NVTE_CHECK(!B.is_none(), "Tensor B has not been provided");
  NVTE_CHECK(!A_scales.is_none(), "Tensor A_scales has not been provided");
  NVTE_CHECK(!B_scales.is_none(), "Tensor B_scales has not been provided");
  auto none = py::none();
  TensorWrapper A_tensor = makeTransformerEngineTensor(A, none);
  TensorWrapper B_tensor = makeTransformerEngineTensor(B, none);
  TensorWrapper A_scales_tensor = makeTransformerEngineTensor(A_scales, none);
  TensorWrapper B_scales_tensor = makeTransformerEngineTensor(B_scales, none);
  const bool low_precision =
      detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype());

  // Check tensor dimensions
  const auto& A_shape = A_tensor.shape();
  const auto& B_shape = B_tensor.shape();
  const auto& D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb);
  NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension");
  NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension");

  // Output tensor
  TensorWrapper D_tensor;
  if (D.is_none()) {
    NVTE_ERROR("tensorwise int8 batchgemm D must not be None.");
  } else {
    D_tensor = makeTransformerEngineTensor(D, quantizer);
    if (out_dtype) {
      NVTE_CHECK(*out_dtype == D_tensor.dtype(), "GEMM output has invalid dtype (expected ",
                 static_cast<int>(*out_dtype), ", found ", static_cast<int>(D_tensor.dtype()), ")");
    }
  }

  // Bias tensor
  TensorWrapper bias_tensor;
  MaybeTensor bias_grad = std::nullopt;
  if (bias.has_value()) {
    if (grad) {
      auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA);
      bias_grad = at::empty({static_cast<int64_t>(B_shape.data[B_shape.ndim - 1])}, opts);
      bias_tensor = makeTransformerEngineTensor(*bias_grad);
    } else {
      if (!bias->is_contiguous()) {
        bias = bias->contiguous();
      }
      bias_tensor = makeTransformerEngineTensor(*bias);
    }
  }

  // Activation input tensor
  MaybeTensor pre_gelu_out = std::nullopt;
  DType gelu_type = low_precision ? bias_type : D_tensor.dtype();
  if (gelu) {
    if (!grad) {
      auto dtype = GetATenDType(gelu_type);
      auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
      std::vector<int64_t> torch_shape;
      for (auto v : D_shape) {
        torch_shape.push_back(v);
      }
      pre_gelu_out = at::empty(torch_shape, opts);
    } else {
      if (gelu_in.has_value()) {
        pre_gelu_out = *gelu_in;
      }
    }
  }
  const auto gelu_shape = gelu ? D_shape : std::vector<size_t>{0};
  auto te_pre_gelu_out =
      makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type);

  // Workspace
  auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
                                                  std::vector<size_t>{workspaceSize}, DType::kByte);

  // Set an external SM Margin to all the GEMMs.
  // This comes in handy when DP is overlapped with GEMMs
  const int device_id = at::cuda::current_device();
  const int sm_count = transformer_engine::cuda::sm_count(device_id);
  int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

  // Keep the swizzled scaling factor tensors alive during the GEMM.
  std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;

  auto main_stream = at::cuda::getCurrentCUDAStream();
  if (A_tensor.numel() != 0 && B_tensor.numel() != 0) {
    // Optionally swizzle the scaling factors
    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa)));
    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(B_tensor, !transb)));

    if (comm_overlap) {
      NVTE_ERROR("tensorwise int8 batchgemm does not support comm_overlap.");
    } else {
      // Launch GEMM
      NVTE_SCOPED_GIL_RELEASE({
        nvte_cublas_batchgemm_v3(A_tensor.data(), B_tensor.data(), A_scales_tensor.data(),
                                 B_scales_tensor.data(), D_tensor.data(), bias_tensor.data(),
                                 te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
                                 accumulate, use_split_accumulator, num_math_sms, batch_count,
                                 main_stream);
      });
    }
  } else {
    if (D_tensor.numel() != 0 && !accumulate) {
      D_tensor.zero_(main_stream);
    }
    if (bias.has_value()) {
      if (bias->numel() != 0 && grad) {
        bias_grad->zero_();
      }
    }
  }

  // Pack outputs
  std::vector<py::object> out;
  out.emplace_back(std::move(D));
  out.emplace_back(py::cast(bias_grad));
  if (gelu && !grad) {
    out.emplace_back(py::cast(*pre_gelu_out));
  } else {
    out.emplace_back(py::none());
  }
  if (extra_output.has_value()) {
    out.emplace_back(py::cast(extra_output));
  } else {
    out.emplace_back(py::none());
  }
  return out;
}
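Editor's note: as a readability aid (not part of the commit), a pure-PyTorch reference for what a tensorwise-int8 batched GEMM computes per batch element, assuming 3-D int8 operands and one float scale per operand; the actual kernel behind nvte_cublas_batchgemm_v3 (see rocm_gemm.cu) additionally handles transposes, bias, GELU, and accumulation as routed by the wrapper above.

import torch

def reference_tensorwise_int8_batchgemm(a_q: torch.Tensor, b_q: torch.Tensor,
                                        a_scale: float, b_scale: float,
                                        out_dtype=torch.bfloat16) -> torch.Tensor:
    """Reference semantics of a tensorwise-int8 batched GEMM (no bias/GELU/transpose).

    a_q: int8, shape [batch, m, k]; b_q: int8, shape [batch, k, n].
    The int8 operands are multiplied in a wide accumulator and the result is
    rescaled by the product of the two per-tensor scales.
    """
    acc = torch.bmm(a_q.float(), b_q.float())  # stands in for the integer accumulator
    return (acc * (a_scale * b_scale)).to(out_dtype)

Whether A_scales/B_scales hold a single value or one value per batch element is defined by the backend and is not visible in this hunk; the sketch assumes a single scale per operand.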
void te_batchgemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
                  transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
                  at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
...
...
transformer_engine/pytorch/csrc/extensions/pybind.cpp  View file @ 7a923605
...
...
@@ -110,13 +110,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"),
        py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
        py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false);
  m.def("generic_batchgemm", transformer_engine::pytorch::generic_batchgemm,
        "Compute Batched GEMM (matrix-matrix multiply)", py::arg("A"), py::arg("transA"),
        py::arg("B"), py::arg("transB"), py::arg("D"), py::arg("batchcount"), py::arg("quantizer"),
        py::arg("output_dtype"), py::arg("bias"), py::arg("bias_type"), py::arg("gelu"),
        py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"), py::arg("workspace_size"),
        py::arg("accumulate"), py::arg("use_split_accumulator"), py::arg("comm_overlap") = nullptr,
        py::arg("comm_type") = std::nullopt, py::arg("extra_output") = std::nullopt,
        py::arg("bulk_overlap") = false);
  m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
        py::arg("quantizer"));
  m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
...
...
@@ -213,6 +206,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm,
        "Grouped GEMM");
#ifdef USE_ROCM
  m.def("generic_batchgemm", transformer_engine::pytorch::generic_batchgemm,
        "Compute Batched GEMM (matrix-matrix multiply)", py::arg("A"), py::arg("transA"),
        py::arg("B"), py::arg("transB"), py::arg("D"), py::arg("batchcount"), py::arg("quantizer"),
        py::arg("output_dtype"), py::arg("bias"), py::arg("bias_type"), py::arg("gelu"),
        py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"), py::arg("workspace_size"),
        py::arg("accumulate"), py::arg("use_split_accumulator"), py::arg("comm_overlap") = nullptr,
        py::arg("comm_type") = std::nullopt, py::arg("extra_output") = std::nullopt,
        py::arg("bulk_overlap") = false);
  m.def("tensorwise_int8_batchgemm", transformer_engine::pytorch::tensorwise_int8_batchgemm,
        "Compute Tensorwise Int8 Batched GEMM (matrix-matrix multiply)", py::arg("A"),
        py::arg("transA"), py::arg("B"), py::arg("transB"), py::arg("A_scales"),
        py::arg("B_scales"), py::arg("D"), py::arg("batchcount"), py::arg("quantizer"),
        py::arg("output_dtype"), py::arg("bias"), py::arg("bias_type"), py::arg("gelu"),
        py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"), py::arg("workspace_size"),
        py::arg("accumulate"), py::arg("use_split_accumulator"), py::arg("comm_overlap") = nullptr,
        py::arg("comm_type") = std::nullopt, py::arg("extra_output") = std::nullopt,
        py::arg("bulk_overlap") = false);
  m.def("te_batchgemm_ts", &transformer_engine::pytorch::te_batchgemm_ts, "Batched GEMM");
  /// rocblas
#endif
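Editor's note: a hypothetical Python-side invocation of the binding registered above. Only the keyword names come from the py::arg list in this file; the module import name, the DType enum spelling, the 3-D operand layout, and the workspace size are assumptions for illustration, not part of this commit.

import torch
import transformer_engine_torch as tex  # assumed import name of this extension

# Illustrative int8 operands and per-tensor scales; the layout expected by the
# backend (stacked 2-D vs. 3-D) is not visible in this diff and is assumed here.
batch, m, k, n = 4, 128, 256, 64
a_q = torch.randint(-127, 128, (batch, m, k), dtype=torch.int8, device="cuda")
b_q = torch.randint(-127, 128, (batch, k, n), dtype=torch.int8, device="cuda")
a_scales = torch.full((batch,), 0.01, dtype=torch.float32, device="cuda")
b_scales = torch.full((batch,), 0.02, dtype=torch.float32, device="cuda")
d = torch.empty((batch, m, n), dtype=torch.bfloat16, device="cuda")
workspace = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device="cuda")  # size is a guess

d_out, bias_grad, gelu_out, extra = tex.tensorwise_int8_batchgemm(
    A=a_q, transA=False, B=b_q, transB=False,
    A_scales=a_scales, B_scales=b_scales,
    D=d, batchcount=batch, quantizer=None,
    output_dtype=tex.DType.kBFloat16,  # enum spelling assumed
    bias=None, bias_type=tex.DType.kBFloat16,
    gelu=False, gelu_in=None, grad=False,
    workspace=workspace, workspace_size=workspace.numel(),
    accumulate=False, use_split_accumulator=False,
)

The four return values correspond to the D / bias_grad / pre-GELU / extra_output slots packed by the C++ wrapper shown earlier in this diff.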
  m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
...
...