Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
7a923605
Commit
7a923605
authored
Aug 21, 2025
by
yuguo
Browse files
[DCU] tensorwise int8 train opt
parent
686e93cd
Changes
9
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
1277 additions
and
272 deletions
+1277
-272
tests/pytorch/test_int8_channelwise_gemm_exact.py
tests/pytorch/test_int8_channelwise_gemm_exact.py
+271
-74
transformer_engine/common/gemm/cublaslt_gemm.cu
transformer_engine/common/gemm/cublaslt_gemm.cu
+76
-1
transformer_engine/common/gemm/rocm_gemm.cu
transformer_engine/common/gemm/rocm_gemm.cu
+458
-3
transformer_engine/common/include/transformer_engine/gemm.h
transformer_engine/common/include/transformer_engine/gemm.h
+5
-0
transformer_engine/common/transformer_engine.cpp
transformer_engine/common/transformer_engine.cpp
+1
-1
transformer_engine/pytorch/cpp_extensions/gemm.py
transformer_engine/pytorch/cpp_extensions/gemm.py
+168
-46
transformer_engine/pytorch/csrc/extensions.h
transformer_engine/pytorch/csrc/extensions.h
+16
-8
transformer_engine/pytorch/csrc/extensions/gemm.cpp
transformer_engine/pytorch/csrc/extensions/gemm.cpp
+268
-132
transformer_engine/pytorch/csrc/extensions/pybind.cpp
transformer_engine/pytorch/csrc/extensions/pybind.cpp
+14
-7
No files found.
tests/pytorch/test_int8_channelwise_gemm_exact.py
View file @
7a923605
...
@@ -32,11 +32,56 @@ import os
...
@@ -32,11 +32,56 @@ import os
int8_simulation_fp8_tensorwise
=
bool
(
int
(
os
.
getenv
(
"NVTE_INT8_SIM_FP8_TENSORWISE"
,
"0"
)))
int8_simulation_fp8_tensorwise
=
bool
(
int
(
os
.
getenv
(
"NVTE_INT8_SIM_FP8_TENSORWISE"
,
"0"
)))
tensorwise_int8_check
=
bool
(
int
(
os
.
getenv
(
"NVTE_INT8_SIM_FP8_TENSORWISE_CHECK"
,
"0"
)))
def
dtype_tols
(
dtype
:
torch
.
dtype
)
->
Dict
[
str
,
float
]:
"""Estimated numerical error for a datatype
Based on tolerances for torch.testing.assert_close.
"""
if
dtype
==
torch
.
float32
:
return
dict
(
rtol
=
1.3e-6
,
atol
=
1e-5
)
if
dtype
==
torch
.
float16
:
return
dict
(
rtol
=
1e-3
,
atol
=
1e-5
)
if
dtype
==
torch
.
bfloat16
:
return
dict
(
rtol
=
1.6e-2
,
atol
=
1e-5
)
raise
ValueError
(
f
"Unsuppored dtype (
{
dtype
}
)"
)
def
assert_allclose
(
l1
:
List
[
torch
.
Tensor
],
l2
:
List
[
torch
.
Tensor
],
atol
:
float
=
None
,
rtol
:
float
=
None
)
->
bool
:
"""Ensures two lists are equal."""
assert
len
(
l1
)
==
len
(
l2
),
"Unequal number of outputs."
for
i
,
(
t1
,
t2
)
in
enumerate
(
zip
(
l1
,
l2
)):
tols
=
dtype_tols
(
t2
.
dtype
)
if
rtol
is
not
None
:
tols
[
"rtol"
]
=
rtol
if
atol
is
not
None
:
tols
[
"atol"
]
=
atol
result
=
torch
.
allclose
(
t1
,
t2
,
**
tols
)
if
not
result
:
diff
=
torch
.
abs
(
t1
-
t2
)
tol
=
tols
[
"atol"
]
+
(
tols
[
"rtol"
]
*
torch
.
abs
(
t2
))
exceed_mask
=
diff
>
tol
if
exceed_mask
.
any
():
indices
=
torch
.
nonzero
(
exceed_mask
,
as_tuple
=
True
)
max_diff
=
diff
[
exceed_mask
].
max
()
max_idx
=
(
diff
[
exceed_mask
]
==
max_diff
).
nonzero
(
as_tuple
=
True
)[
0
][
0
]
max_location
=
[
idx
[
max_idx
].
item
()
for
idx
in
indices
]
msg
=
(
f
"Outputs not close enough in tensor at idx=
{
i
}
. "
f
"Maximum difference at location
{
max_location
}
"
f
"with
{
t1
[
exceed_mask
][
max_idx
].
item
()
}
vs
{
t2
[
exceed_mask
][
max_idx
].
item
()
}
"
f
"(diff
{
max_diff
.
item
()
}
)."
)
raise
AssertionError
(
msg
)
# TN
# TN
m
=
4096
m
=
4096
k
=
4096
k
=
4096
n
=
4096
n
=
6144
seed
=
0
seed
=
4096
torch
.
manual_seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
device
=
"cuda"
device
=
"cuda"
...
@@ -235,40 +280,42 @@ transb = False
...
@@ -235,40 +280,42 @@ transb = False
x_bf16
=
(
torch
.
randn
((
m
,
k
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
x_bf16
=
(
torch
.
randn
((
m
,
k
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
w_bf16
=
(
torch
.
randn
((
n
,
k
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
w_bf16
=
(
torch
.
randn
((
n
,
k
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
output
=
(
torch
.
randn
((
m
,
n
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
output
=
(
torch
.
randn
((
m
,
n
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
for
i
in
range
(
20
):
bf16_out
=
torch
.
matmul
(
x_bf16
,
w_bf16
.
t
())
bf16_out
=
torch
.
matmul
(
x_bf16
,
w_bf16
.
t
())
print
(
"bf16_out: "
,
bf16_out
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
start
=
time
.
time
()
start
=
time
.
time
()
for
i
in
range
(
20
):
for
i
in
range
(
20
):
bf16_out
=
torch
.
matmul
(
x_bf16
,
w_bf16
.
t
())
bf16_out
=
torch
.
matmul
(
x_bf16
,
w_bf16
.
t
())
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
end
=
time
.
time
()
end
=
time
.
time
()
# x_int8, x_scales = per_token_quant_int8(x_bf16)
# x_int8, x_scales = per_token_quant_int8(x_bf16)
# w_int8, w_scales = per_token_quant_int8(w_bf16)
# w_int8, w_scales = per_token_quant_int8(w_bf16)
# print("x_int8: ", x_int8)
# print("x_int8: ", x_int8)
# print("w_int8: ", w_int8)
# print("w_int8: ", w_int8)
# Cast to FP8 and back
# Cast to FP8 and back
x_fp8
=
to_float8_CS
(
x_bf16
,
fp8_dtype
=
tex
.
DType
.
kFloat8E4M3
)
x_fp8
=
to_float8_CS
(
x_bf16
,
fp8_dtype
=
tex
.
DType
.
kFloat8E4M3
)
w_fp8
=
to_float8_CS
(
w_bf16
,
fp8_dtype
=
tex
.
DType
.
kFloat8E4M3
)
w_fp8
=
to_float8_CS
(
w_bf16
,
fp8_dtype
=
tex
.
DType
.
kFloat8E4M3
)
# print("x_fp8: ", x_fp8._data.view(dtype=torch.float8_e4m3fn))
# print("x_fp8: ", x_fp8._data.view(dtype=torch.float8_e4m3fn))
# print("w_fp8: ", w_fp8._data.view(dtype=torch.float8_e4m3fn))
# print("w_fp8: ", w_fp8._data.view(dtype=torch.float8_e4m3fn))
if
int8_simulation_fp8_tensorwise
:
if
int8_simulation_fp8_tensorwise
:
x_int8
,
x_scales
=
x_fp8
.
_data
.
view
(
dtype
=
torch
.
int8
),
x_fp8
.
_scale_inv
x_int8
,
x_scales
=
x_fp8
.
_data
.
view
(
dtype
=
torch
.
int8
),
x_fp8
.
_scale_inv
w_int8
,
w_scales
=
w_fp8
.
_data
.
view
(
dtype
=
torch
.
int8
),
w_fp8
.
_scale_inv
w_int8
,
w_scales
=
w_fp8
.
_data
.
view
(
dtype
=
torch
.
int8
),
w_fp8
.
_scale_inv
else
:
else
:
x_int8
,
x_scales
=
per_token_quant_fp8_to_int8
(
x_fp8
.
_data
.
view
(
dtype
=
torch
.
float8_e4m3fn
),
x_fp8
.
_scale_inv
,
False
)
x_int8
,
x_scales
=
per_token_quant_fp8_to_int8
(
x_fp8
.
_data
.
view
(
dtype
=
torch
.
float8_e4m3fn
),
x_fp8
.
_scale_inv
,
False
)
w_int8
,
w_scales
=
per_token_quant_fp8_to_int8
(
w_fp8
.
_data
.
view
(
dtype
=
torch
.
float8_e4m3fn
),
w_fp8
.
_scale_inv
,
False
)
w_int8
,
w_scales
=
per_token_quant_fp8_to_int8
(
w_fp8
.
_data
.
view
(
dtype
=
torch
.
float8_e4m3fn
),
w_fp8
.
_scale_inv
,
False
)
# print("x_int8: ", x_int8)
# print("x_int8: ", x_int8)
# print("w_int8: ", w_int8)
# print("w_int8: ", w_int8)
# cuBLAS GEMM
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
# We are just capturing out.
y_int32
=
tex
.
generic_gemm
(
y_int32
=
tex
.
generic_gemm
(
w_int8
,
w_int8
,
transa
,
transa
,
x_int8
,
x_int8
,
...
@@ -285,19 +332,42 @@ y_int32 = tex.generic_gemm(
...
@@ -285,19 +332,42 @@ y_int32 = tex.generic_gemm(
workspace
.
shape
[
0
],
workspace
.
shape
[
0
],
accumulate
,
accumulate
,
use_split_accumulator
,
use_split_accumulator
,
)[
0
]
)[
0
]
# y_int32 = torch._int_mm(x_int8, w_int8.t())
# y_int32 = torch._int_mm(x_int8, w_int8.t())
# print("y_int32: ", y_int32)
# print("y_int32: ", y_int32)
if
int8_simulation_fp8_tensorwise
:
if
int8_simulation_fp8_tensorwise
:
tensorwise_dequantize
(
x_scales
,
w_scales
,
y_int32
,
output
)
tensorwise_dequantize
(
x_scales
,
w_scales
,
y_int32
,
output
)
else
:
else
:
output
=
channelwise_dequantize_transB
(
x_scales
,
w_scales
,
y_int32
)
output
=
channelwise_dequantize_transB
(
x_scales
,
w_scales
,
y_int32
)
# print("out_scales.shape: ", out_scales.shape)
print
(
"output: "
,
output
)
# print("out_scales: ", out_scales)
print
(
"bf16_out: "
,
bf16_out
)
if
tensorwise_int8_check
:
print
(
"output: "
,
output
)
lt_output
=
tex
.
generic_gemm
(
w_fp8
,
transa
,
x_fp8
,
transb
,
out
,
out_quantizer
,
TE_DType
[
torch
.
bfloat16
],
bias
,
bias_dtype
,
use_gelu
,
aux_tensor
,
use_grad
,
workspace
,
workspace
.
shape
[
0
],
accumulate
,
True
,
)[
0
]
print
(
"lt_output: "
,
lt_output
)
assert_allclose
([
output
],
[
lt_output
])
# print("out_scales.shape: ", out_scales.shape)
# print("out_scales: ", out_scales)
# torch.cuda.synchronize()
# torch.cuda.synchronize()
# start = time.time()
# start = time.time()
...
@@ -339,6 +409,7 @@ w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
...
@@ -339,6 +409,7 @@ w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
dx
=
(
torch
.
randn
((
m
,
k
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
dx
=
(
torch
.
randn
((
m
,
k
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
bf16_dx
=
torch
.
matmul
(
dy_bf16
,
w_bf16
)
bf16_dx
=
torch
.
matmul
(
dy_bf16
,
w_bf16
)
print
(
"bf16_dx: "
,
bf16_dx
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
start
=
time
.
time
()
start
=
time
.
time
()
...
@@ -397,11 +468,32 @@ if int8_simulation_fp8_tensorwise:
...
@@ -397,11 +468,32 @@ if int8_simulation_fp8_tensorwise:
else
:
else
:
dx
=
channelwise_dequantize
(
dy_scales
,
w_scales
,
dx_int32
)
dx
=
channelwise_dequantize
(
dy_scales
,
w_scales
,
dx_int32
)
# dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
# dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
print
(
"dx: "
,
dx
)
if
tensorwise_int8_check
:
lt_dx
=
tex
.
generic_gemm
(
w_fp8
,
transa
,
dy_fp8
,
transb
,
out
,
out_quantizer
,
TE_DType
[
torch
.
bfloat16
],
bias
,
bias_dtype
,
use_gelu
,
aux_tensor
,
use_grad
,
workspace
,
workspace
.
shape
[
0
],
accumulate
,
True
,
)[
0
]
print
(
"lt_dx: "
,
lt_dx
)
assert_allclose
([
dx
],
[
lt_dx
])
# print("dx_scales.shape: ", dx_scales.shape)
# print("dx_scales.shape: ", dx_scales.shape)
# print("dx_scales: ", dx_scales)
# print("dx_scales: ", dx_scales)
print
(
"bf16_dx: "
,
bf16_dx
)
print
(
"dx: "
,
dx
)
# torch.cuda.synchronize()
# torch.cuda.synchronize()
# start = time.time()
# start = time.time()
...
@@ -447,11 +539,9 @@ transb = True
...
@@ -447,11 +539,9 @@ transb = True
dy_bf16
=
(
torch
.
randn
((
m
,
n
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
dy_bf16
=
(
torch
.
randn
((
m
,
n
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
x_bf16
=
(
torch
.
randn
((
m
,
k
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
x_bf16
=
(
torch
.
randn
((
m
,
k
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
dw
=
(
torch
.
randn
((
n
,
k
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
dw
=
(
torch
.
randn
((
n
,
k
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
seed
=
0
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
bf16_dw
=
torch
.
matmul
(
dy_bf16
.
t
(),
x_bf16
)
bf16_dw
=
torch
.
matmul
(
dy_bf16
.
t
(),
x_bf16
)
print
(
"bf16_dw: "
,
bf16_dw
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
start
=
time
.
time
()
start
=
time
.
time
()
...
@@ -504,9 +594,30 @@ if int8_simulation_fp8_tensorwise:
...
@@ -504,9 +594,30 @@ if int8_simulation_fp8_tensorwise:
else
:
else
:
dw
=
channelwise_dequantize_transA
(
dy_scales
,
x_scales
,
dw_int32
)
dw
=
channelwise_dequantize_transA
(
dy_scales
,
x_scales
,
dw_int32
)
# dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
# dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
print
(
"bf16_dw: "
,
bf16_dw
)
print
(
"dw: "
,
dw
)
print
(
"dw: "
,
dw
)
if
tensorwise_int8_check
:
lt_dw
=
tex
.
generic_gemm
(
x_fp8
,
transa
,
dy_fp8
,
transb
,
out
,
out_quantizer
,
TE_DType
[
torch
.
bfloat16
],
bias
,
bias_dtype
,
use_gelu
,
aux_tensor
,
use_grad
,
workspace
,
workspace
.
shape
[
0
],
accumulate
,
True
,
)[
0
]
print
(
"lt_dw: "
,
lt_dw
)
assert_allclose
([
dw
],
[
lt_dw
])
# torch.cuda.synchronize()
# torch.cuda.synchronize()
# start = time.time()
# start = time.time()
# for i in range(20):
# for i in range(20):
...
@@ -548,9 +659,9 @@ print("dw: ", dw)
...
@@ -548,9 +659,9 @@ print("dw: ", dw)
# bacth gemm wgrad
# bacth gemm wgrad
m
=
32
m
=
1024
k
=
32
k
=
1024
n
=
32
n
=
1024
b
=
4
b
=
4
transa
=
False
transa
=
False
...
@@ -558,9 +669,6 @@ transb = True
...
@@ -558,9 +669,6 @@ transb = True
dy_int8
=
(
torch
.
randn
((
b
,
m
,
n
),
device
=
device
)).
to
(
dtype
=
torch
.
int8
)
dy_int8
=
(
torch
.
randn
((
b
,
m
,
n
),
device
=
device
)).
to
(
dtype
=
torch
.
int8
)
x_int8
=
(
torch
.
randn
((
b
,
m
,
k
),
device
=
device
)).
to
(
dtype
=
torch
.
int8
)
x_int8
=
(
torch
.
randn
((
b
,
m
,
k
),
device
=
device
)).
to
(
dtype
=
torch
.
int8
)
seed
=
0
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
int32_dw_list
=
[]
int32_dw_list
=
[]
for
i
in
range
(
b
):
for
i
in
range
(
b
):
...
@@ -597,3 +705,92 @@ te_dw = tex.generic_batchgemm(
...
@@ -597,3 +705,92 @@ te_dw = tex.generic_batchgemm(
# print("te_dw.shape: ", te_dw.view(b, -1, te_dw.size(-1)).shape)
# print("te_dw.shape: ", te_dw.view(b, -1, te_dw.size(-1)).shape)
# print("te_dw: ", te_dw.view(b, -1, te_dw.size(-1)))
# print("te_dw: ", te_dw.view(b, -1, te_dw.size(-1)))
torch
.
testing
.
assert_close
(
te_dw
.
view
(
b
,
-
1
,
te_dw
.
size
(
-
1
)),
batched_int32_dw
,
atol
=
0
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
te_dw
.
view
(
b
,
-
1
,
te_dw
.
size
(
-
1
)),
batched_int32_dw
,
atol
=
0
,
rtol
=
0
)
# NT
b
=
4
transa
=
False
transb
=
True
dy_bf16
=
[(
torch
.
randn
((
m
,
n
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
for
i
in
range
(
b
)]
x_bf16
=
[(
torch
.
randn
((
m
,
k
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
for
i
in
range
(
b
)]
dw_ref
=
[(
torch
.
randn
((
n
,
k
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
for
i
in
range
(
b
)]
dw
=
[(
torch
.
randn
((
n
,
k
),
device
=
device
)).
to
(
dtype
=
torch
.
bfloat16
)
for
i
in
range
(
b
)]
# Cast to FP8 and back
dy_fp8
=
[
to_float8_CS
(
dy_bf16
[
i
],
fp8_dtype
=
tex
.
DType
.
kFloat8E5M2
)
for
i
in
range
(
b
)]
x_fp8
=
[
to_float8_CS
(
x_bf16
[
i
],
fp8_dtype
=
tex
.
DType
.
kFloat8E5M2
)
for
i
in
range
(
b
)]
if
int8_simulation_fp8_tensorwise
:
dy_int8
,
dy_scales
=
[
dy_fp8
[
i
].
_data
.
view
(
dtype
=
torch
.
int8
)
for
i
in
range
(
b
)],
[
dy_fp8
[
i
].
_scale_inv
for
i
in
range
(
b
)]
x_int8
,
x_scales
=
[
x_fp8
[
i
].
_data
.
view
(
dtype
=
torch
.
int8
)
for
i
in
range
(
b
)],
[
x_fp8
[
i
].
_scale_inv
for
i
in
range
(
b
)]
else
:
dy_int8
,
dy_scales
=
[],
[]
x_int8
,
x_scales
=
[],
[]
assert
False
for
i
in
range
(
b
):
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
dw_int32
=
tex
.
generic_gemm
(
x_int8
[
i
],
transa
,
dy_int8
[
i
],
transb
,
None
,
None
,
TE_DType
[
out_dtype
],
bias
,
bias_dtype
,
use_gelu
,
aux_tensor
,
use_grad
,
workspace
,
workspace
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
)[
0
]
if
int8_simulation_fp8_tensorwise
:
tensorwise_dequantize
(
dy_scales
[
i
],
x_scales
[
i
],
dw_int32
,
dw_ref
[
i
])
else
:
assert
False
dw_ref_tensor
=
torch
.
stack
(
dw_ref
).
contiguous
()
# print("dw_ref_tensor: ", dw_ref_tensor)
torch
.
cuda
.
synchronize
()
dy_int8_tensor
=
torch
.
stack
(
dy_int8
).
contiguous
()
dy_scales_tensor
=
torch
.
stack
(
dy_scales
).
contiguous
()
x_int8_tensor
=
torch
.
stack
(
x_int8
).
contiguous
()
x_scales_tensor
=
torch
.
stack
(
x_scales
).
contiguous
()
dw_tensor
=
torch
.
stack
(
dw
).
contiguous
()
out_dtype
=
torch
.
bfloat16
dw_tensor
=
tex
.
tensorwise_int8_batchgemm
(
x_int8_tensor
.
view
(
-
1
,
x_int8
.
size
(
-
1
)),
transa
,
dy_int8_tensor
.
view
(
-
1
,
dy_int8
.
size
(
-
1
)),
transb
,
x_scales_tensor
,
dy_scales_tensor
,
dw_tensor
.
view
(
-
1
,
dw
.
size
(
-
1
)),
b
,
out_quantizer
,
TE_DType
[
out_dtype
],
bias
,
bias_dtype
,
use_gelu
,
aux_tensor
,
use_grad
,
workspace
,
workspace
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
)[
0
]
# print("dw_tensor: ", dw_tensor)
torch
.
testing
.
assert_close
(
dw_ref_tensor
,
dw_tensor
,
atol
=
1e-5
,
rtol
=
1e-5
)
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @
7a923605
...
@@ -684,8 +684,9 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
...
@@ -684,8 +684,9 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
const
char
*
NVTE_FORCE_ROCM_GEMM
=
std
::
getenv
(
"NVTE_FORCE_ROCM_GEMM"
);
const
char
*
NVTE_FORCE_ROCM_GEMM
=
std
::
getenv
(
"NVTE_FORCE_ROCM_GEMM"
);
const
bool
use_fp8
=
is_fp8_dtype
(
inputA
->
data
.
dtype
)
||
const
bool
use_fp8
=
is_fp8_dtype
(
inputA
->
data
.
dtype
)
||
is_fp8_dtype
(
inputB
->
data
.
dtype
);
is_fp8_dtype
(
inputB
->
data
.
dtype
);
const
char
*
NVTE_INT8_SIM_FP8_TENSORWISE
=
std
::
getenv
(
"NVTE_INT8_SIM_FP8_TENSORWISE"
);
if
(
NVTE_INT8_SIM_FP8_TENSORWISE
!=
nullptr
&&
NVTE_INT8_SIM_FP8_TENSORWISE
[
0
]
==
'1'
&&
use_int8
&&
use_split_accumulator
)
nvte_use_hipblaslt
=
1
;
if
((
biasTensor
->
data
.
dptr
!=
nullptr
)
||
(
outputGelu
->
data
.
dptr
!=
nullptr
)
||
(
use_fp8
)
||
(
NVTE_FORCE_ROCM_GEMM
!=
nullptr
&&
NVTE_FORCE_ROCM_GEMM
[
0
]
==
'1'
)
||
(
nvte_use_hipblaslt
)
||
(
nvte_use_rocblas
))
{
if
((
biasTensor
->
data
.
dptr
!=
nullptr
)
||
(
outputGelu
->
data
.
dptr
!=
nullptr
)
||
(
use_fp8
)
||
(
NVTE_FORCE_ROCM_GEMM
!=
nullptr
&&
NVTE_FORCE_ROCM_GEMM
[
0
]
==
'1'
)
||
(
nvte_use_hipblaslt
)
||
(
nvte_use_rocblas
))
{
NVTE_CHECK
(
!
use_int8
,
"Int8 gemm just surpport pure int8 gemm without any epilogue."
);
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
transa
,
transb
,
grad
,
cublas_gemm
(
inputA
,
inputB
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
transa
,
transb
,
grad
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
math_sm_count
,
0
,
0
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
math_sm_count
,
0
,
0
,
false
,
nullptr
,
stream
,
nvte_use_hipblaslt
,
nvte_use_rocblas
,
compute_stream_offset
);
false
,
nullptr
,
stream
,
nvte_use_hipblaslt
,
nvte_use_rocblas
,
compute_stream_offset
);
...
@@ -1100,4 +1101,78 @@ void nvte_cublas_batchgemm_v2(const NVTETensor A, const NVTETensor B, NVTETensor
...
@@ -1100,4 +1101,78 @@ void nvte_cublas_batchgemm_v2(const NVTETensor A, const NVTETensor B, NVTETensor
batch_count
,
batch_count
,
stream
);
stream
);
}
}
// add for batchgemm
void
nvte_cublas_batchgemm_v3
(
const
NVTETensor
A
,
const
NVTETensor
B
,
const
NVTETensor
A_scales
,
const
NVTETensor
B_scales
,
NVTETensor
D
,
const
NVTETensor
bias
,
NVTETensor
pre_gelu_out
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
batch_count
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_cublas_batchgemm_v3
);
using
namespace
transformer_engine
;
const
Tensor
*
inputA
=
convertNVTETensorCheck
(
A
);
const
Tensor
*
inputB
=
convertNVTETensorCheck
(
B
);
const
Tensor
*
inputA_scales
=
convertNVTETensorCheck
(
A_scales
);
const
Tensor
*
inputB_scales
=
convertNVTETensorCheck
(
B_scales
);
Tensor
*
outputD
=
convertNVTETensor
(
D
);
const
Tensor
*
biasTensor
=
convertNVTETensor
(
bias
);
Tensor
*
outputGelu
=
convertNVTETensor
(
pre_gelu_out
);
Tensor
*
wspace
=
convertNVTETensor
(
workspace
);
if
((
biasTensor
->
data
.
dptr
!=
nullptr
)
||
(
outputGelu
->
data
.
dptr
!=
nullptr
))
{
NVTE_ERROR
(
"MOE batchgemm not surpport bias or gelu."
);
}
int
m
,
n
,
k
;
if
(
!
transa
&&
transb
)
{
// for NT
m
=
transa
?
inputA
->
data
.
shape
[
0
]
/
batch_count
:
inputA
->
data
.
shape
[
1
];
k
=
transa
?
inputA
->
data
.
shape
[
1
]
:
inputA
->
data
.
shape
[
0
]
/
batch_count
;
n
=
transb
?
inputB
->
data
.
shape
[
1
]
:
inputB
->
data
.
shape
[
0
]
/
batch_count
;
}
else
if
(
transa
&&
!
transb
){
// for TN
m
=
transa
?
inputA
->
data
.
shape
[
0
]
/
batch_count
:
inputA
->
data
.
shape
[
1
];
k
=
transa
?
inputA
->
data
.
shape
[
1
]
:
inputA
->
data
.
shape
[
0
]
/
batch_count
;
n
=
transb
?
inputB
->
data
.
shape
[
1
]
:
inputB
->
data
.
shape
[
0
]
/
batch_count
;
}
else
if
(
!
transa
&&
!
transb
){
// for NN
m
=
transa
?
inputA
->
data
.
shape
[
0
]
/
batch_count
:
inputA
->
data
.
shape
[
1
];
k
=
transa
?
inputA
->
data
.
shape
[
1
]
:
inputA
->
data
.
shape
[
0
]
/
batch_count
;
n
=
transb
?
inputB
->
data
.
shape
[
1
]
:
inputB
->
data
.
shape
[
0
]
/
batch_count
;
}
int
lda
,
ldb
,
ldd
;
if
(
transa
&&
!
transb
)
{
// TN
lda
=
k
;
ldb
=
k
;
ldd
=
m
;
}
else
if
(
!
transa
&&
!
transb
)
{
// NN
lda
=
m
;
ldb
=
k
;
ldd
=
m
;
}
else
if
(
!
transa
&&
transb
)
{
// NT
lda
=
m
;
ldb
=
n
;
ldd
=
m
;
}
else
{
// TT
NVTE_ERROR
(
"TT layout not allowed."
);
}
hipblasLtHandle_t
handle
=
nullptr
;
// Init hipblaslt handles (once, globally)
static
std
::
once_flag
init_flag
;
static
hipblasLtHandle_t
hipblaslt_handles
[
compute_num_streams
];
std
::
call_once
(
init_flag
,
init_hipblaslt_handles
,
hipblaslt_handles
);
handle
=
hipblaslt_handles
[
0
];
hipblaslt_batchgemm_tensorwise_int8
(
inputA
,
inputB
,
inputA_scales
,
inputB_scales
,
outputD
,
biasTensor
,
outputGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
(
transa
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
(
transb
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
grad
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
math_sm_count
,
0
,
0
,
false
,
nullptr
,
batch_count
,
stream
,
handle
);
}
#endif
#endif
\ No newline at end of file
transformer_engine/common/gemm/rocm_gemm.cu
View file @
7a923605
This diff is collapsed.
Click to expand it.
transformer_engine/common/include/transformer_engine/gemm.h
View file @
7a923605
...
@@ -134,6 +134,11 @@ void nvte_cublas_batchgemm_v2(const NVTETensor A, const NVTETensor B, NVTETensor
...
@@ -134,6 +134,11 @@ void nvte_cublas_batchgemm_v2(const NVTETensor A, const NVTETensor B, NVTETensor
NVTETensor
pre_gelu_out
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
pre_gelu_out
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
batch_count
,
cudaStream_t
stream
);
int
math_sm_count
,
int
batch_count
,
cudaStream_t
stream
);
void
nvte_cublas_batchgemm_v3
(
const
NVTETensor
A
,
const
NVTETensor
B
,
const
NVTETensor
A_scales
,
const
NVTETensor
B_scales
,
NVTETensor
D
,
const
NVTETensor
bias
,
NVTETensor
pre_gelu_out
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
batch_count
,
cudaStream_t
stream
);
#endif
#endif
#ifdef __cplusplus
#ifdef __cplusplus
...
...
transformer_engine/common/transformer_engine.cpp
View file @
7a923605
...
@@ -131,7 +131,7 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
...
@@ -131,7 +131,7 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
void
CheckInputTensor
(
const
Tensor
&
t
,
const
std
::
string
&
name
)
{
void
CheckInputTensor
(
const
Tensor
&
t
,
const
std
::
string
&
name
)
{
const
DType
type
=
t
.
dtype
();
const
DType
type
=
t
.
dtype
();
if
(
is_fp8_dtype
(
type
))
{
if
(
is_fp8_dtype
(
type
)
||
is_int8_dtype
(
type
)
)
{
// FP8 input needs to have scale_inv
// FP8 input needs to have scale_inv
if
(
t
.
has_data
())
{
if
(
t
.
has_data
())
{
NVTE_CHECK
(
t
.
scale_inv
.
dptr
!=
nullptr
,
"FP8 scaling factor input "
,
name
,
NVTE_CHECK
(
t
.
scale_inv
.
dptr
!=
nullptr
,
"FP8 scaling factor input "
,
name
,
...
...
transformer_engine/pytorch/cpp_extensions/gemm.py
View file @
7a923605
...
@@ -38,6 +38,7 @@ from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_q
...
@@ -38,6 +38,7 @@ from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_q
from
transformer_engine.pytorch.utils
import
get_device_compute_capability
from
transformer_engine.pytorch.utils
import
get_device_compute_capability
from
transformer_engine.pytorch.fp8
import
int8_simulation_fp8
,
int8_simulation_fp8_tensorwise
from
transformer_engine.pytorch.fp8
import
int8_simulation_fp8
,
int8_simulation_fp8_tensorwise
tensorwise_int8_check
=
bool
(
int
(
os
.
getenv
(
"NVTE_INT8_SIM_FP8_TENSORWISE_CHECK"
,
"0"
)))
__all__
=
[
__all__
=
[
"general_gemm"
,
"general_gemm"
,
"general_grouped_gemm"
,
"general_grouped_gemm"
,
...
@@ -181,6 +182,47 @@ def general_gemm(
...
@@ -181,6 +182,47 @@ def general_gemm(
):
):
raise
RuntimeError
(
"GEMM with Float8BlockwiseQTensor requires GEMM_READY format"
)
raise
RuntimeError
(
"GEMM with Float8BlockwiseQTensor requires GEMM_READY format"
)
if
int8_simulation_fp8
and
(
isinstance
(
A
,
Float8TensorBase
)
or
isinstance
(
B
,
Float8TensorBase
))
and
int8_simulation_fp8_tensorwise
:
assert
not
gelu
,
"GELU not supported with int8 simulation"
assert
gelu_in
is
None
,
"GELU input not supported with int8 simulation"
assert
bias
is
None
,
"Bias not supported with int8 simulation"
assert
ub
is
None
,
"User buffer not supported with int8 simulation"
assert
ub_type
is
None
,
"User buffer type not supported with int8 simulation"
assert
extra_output
is
None
,
"Extra output not supported with int8 simulation"
assert
not
bulk_overlap
,
"Bulk overlap not supported with int8 simulation"
assert
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float32
,
"Out_dtype must be bfloat16 or float32 for int8 simulation"
args
=
(
A
,
transa
,
# transa
B
,
transb
,
# transb
out
,
quantization_params
,
TE_DType
[
out_dtype
]
if
out_dtype
is
not
None
else
None
,
bias
,
bias_dtype
,
gelu
,
gelu_in
,
grad
,
# grad
workspace
,
workspace
.
shape
[
0
],
accumulate
,
True
,
)
kwargs
=
{
"comm_overlap"
:
ub
,
"comm_type"
:
ub_type
,
"extra_output"
:
extra_output
,
"bulk_overlap"
:
bulk_overlap
,
}
out
,
bias_grad
,
gelu_input
,
extra_output
=
tex
.
generic_gemm
(
*
args
,
**
kwargs
)
if
debug_quantizer
is
not
None
:
out
=
debug_quantizer
.
process_gemm_output
(
out
)
return
out
,
bias_grad
,
gelu_input
,
extra_output
if
int8_simulation_fp8
and
(
isinstance
(
A
,
Float8TensorBase
)
or
isinstance
(
B
,
Float8TensorBase
)):
if
int8_simulation_fp8
and
(
isinstance
(
A
,
Float8TensorBase
)
or
isinstance
(
B
,
Float8TensorBase
)):
assert
not
gelu
,
"GELU not supported with int8 simulation"
assert
not
gelu
,
"GELU not supported with int8 simulation"
...
@@ -195,10 +237,6 @@ def general_gemm(
...
@@ -195,10 +237,6 @@ def general_gemm(
if
layout
==
"TN"
:
if
layout
==
"TN"
:
assert
out_dtype
is
torch
.
bfloat16
assert
out_dtype
is
torch
.
bfloat16
out_shape
=
B
.
_data
.
shape
[:
-
1
]
+
(
A
.
_data
.
shape
[
0
],
)
out_shape
=
B
.
_data
.
shape
[:
-
1
]
+
(
A
.
_data
.
shape
[
0
],
)
if
int8_simulation_fp8_tensorwise
:
x_int8
,
x_scales
=
B
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
B
.
_fp8_dtype
]),
B
.
_scale_inv
w_int8
,
w_scales
=
A
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
A
.
_fp8_dtype
]),
A
.
_scale_inv
else
:
x_int8
,
x_scales
=
per_token_quant_fp8_to_int8
(
B
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
B
.
_fp8_dtype
]),
B
.
_scale_inv
,
False
)
x_int8
,
x_scales
=
per_token_quant_fp8_to_int8
(
B
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
B
.
_fp8_dtype
]),
B
.
_scale_inv
,
False
)
w_int8
,
w_scales
=
per_token_quant_fp8_to_int8
(
A
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
A
.
_fp8_dtype
]),
A
.
_scale_inv
,
False
)
w_int8
,
w_scales
=
per_token_quant_fp8_to_int8
(
A
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
A
.
_fp8_dtype
]),
A
.
_scale_inv
,
False
)
...
@@ -220,20 +258,12 @@ def general_gemm(
...
@@ -220,20 +258,12 @@ def general_gemm(
False
,
False
,
use_split_accumulator
,
use_split_accumulator
,
)[
0
]
)[
0
]
if
int8_simulation_fp8_tensorwise
:
y
=
torch
.
empty_like
(
y_int32
,
device
=
y_int32
.
device
,
dtype
=
torch
.
bfloat16
)
tensorwise_dequantize
(
x_scales
,
w_scales
,
y_int32
,
y
)
else
:
y
=
channelwise_dequantize_transB
(
x_scales
,
w_scales
,
y_int32
)
y
=
channelwise_dequantize_transB
(
x_scales
,
w_scales
,
y_int32
)
return
y
.
view
(
out_shape
),
None
,
None
,
None
return
y
.
view
(
out_shape
),
None
,
None
,
None
elif
layout
==
"NN"
:
elif
layout
==
"NN"
:
assert
out_dtype
is
torch
.
bfloat16
assert
out_dtype
is
torch
.
bfloat16
dx_shape
=
B
.
_data
.
shape
[:
-
1
]
+
(
A
.
_data
.
shape
[
-
1
],
)
dx_shape
=
B
.
_data
.
shape
[:
-
1
]
+
(
A
.
_data
.
shape
[
-
1
],
)
if
int8_simulation_fp8_tensorwise
:
dy_int8
,
dy_scales
=
B
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
B
.
_fp8_dtype
]),
B
.
_scale_inv
w_int8
,
w_scales
=
A
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
A
.
_fp8_dtype
]),
A
.
_scale_inv
else
:
dy_int8
,
dy_scales
=
per_token_quant_fp8_to_int8
(
B
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
B
.
_fp8_dtype
]),
B
.
_scale_inv
,
False
)
dy_int8
,
dy_scales
=
per_token_quant_fp8_to_int8
(
B
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
B
.
_fp8_dtype
]),
B
.
_scale_inv
,
False
)
w_int8
,
w_scales
=
per_token_quant_fp8_to_int8_opt
(
A
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
A
.
_fp8_dtype
]),
A
.
_scale_inv
,
False
)
w_int8
,
w_scales
=
per_token_quant_fp8_to_int8_opt
(
A
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
A
.
_fp8_dtype
]),
A
.
_scale_inv
,
False
)
...
@@ -255,19 +285,11 @@ def general_gemm(
...
@@ -255,19 +285,11 @@ def general_gemm(
False
,
False
,
use_split_accumulator
,
use_split_accumulator
,
)[
0
]
)[
0
]
if
int8_simulation_fp8_tensorwise
:
dx
=
torch
.
empty_like
(
dx_int32
,
device
=
dx_int32
.
device
,
dtype
=
torch
.
bfloat16
)
tensorwise_dequantize
(
dy_scales
,
w_scales
,
dx_int32
,
dx
)
else
:
dx
=
channelwise_dequantize
(
dy_scales
,
w_scales
,
dx_int32
)
dx
=
channelwise_dequantize
(
dy_scales
,
w_scales
,
dx_int32
)
return
dx
.
view
(
dx_shape
),
None
,
None
,
None
return
dx
.
view
(
dx_shape
),
None
,
None
,
None
elif
layout
==
"NT"
:
elif
layout
==
"NT"
:
assert
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float32
assert
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float32
if
int8_simulation_fp8_tensorwise
:
dy_int8
,
dy_scales
=
B
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
B
.
_fp8_dtype
]),
B
.
_scale_inv
x_int8
,
x_scales
=
A
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
A
.
_fp8_dtype
]),
A
.
_scale_inv
else
:
dy_int8
,
dy_scales
=
per_token_quant_fp8_to_int8_opt
(
B
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
B
.
_fp8_dtype
]),
B
.
_scale_inv
,
False
)
dy_int8
,
dy_scales
=
per_token_quant_fp8_to_int8_opt
(
B
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
B
.
_fp8_dtype
]),
B
.
_scale_inv
,
False
)
x_int8
,
x_scales
=
per_token_quant_fp8_to_int8_opt
(
A
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
A
.
_fp8_dtype
]),
A
.
_scale_inv
,
False
)
x_int8
,
x_scales
=
per_token_quant_fp8_to_int8_opt
(
A
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
A
.
_fp8_dtype
]),
A
.
_scale_inv
,
False
)
...
@@ -291,27 +313,13 @@ def general_gemm(
...
@@ -291,27 +313,13 @@ def general_gemm(
)[
0
]
)[
0
]
if
out_dtype
is
torch
.
bfloat16
:
if
out_dtype
is
torch
.
bfloat16
:
if
accumulate
:
if
accumulate
:
if
int8_simulation_fp8_tensorwise
:
tensorwise_dequantize_add
(
dy_scales
,
x_scales
,
dw_int32
,
out
)
else
:
channelwise_dequantize_transA_add
(
dy_scales
,
x_scales
,
dw_int32
,
out
)
channelwise_dequantize_transA_add
(
dy_scales
,
x_scales
,
dw_int32
,
out
)
else
:
if
int8_simulation_fp8_tensorwise
:
out
=
torch
.
empty_like
(
dw_int32
,
device
=
dw_int32
.
device
,
dtype
=
torch
.
bfloat16
)
tensorwise_dequantize
(
dy_scales
,
x_scales
,
dw_int32
,
out
)
else
:
else
:
out
=
channelwise_dequantize_transA
(
dy_scales
,
x_scales
,
dw_int32
)
out
=
channelwise_dequantize_transA
(
dy_scales
,
x_scales
,
dw_int32
)
else
:
else
:
if
accumulate
:
if
accumulate
:
if
int8_simulation_fp8_tensorwise
:
tensorwise_dequantize_float_add
(
dy_scales
,
x_scales
,
dw_int32
,
out
)
else
:
channelwise_dequantize_transA_float_add
(
dy_scales
,
x_scales
,
dw_int32
,
out
)
channelwise_dequantize_transA_float_add
(
dy_scales
,
x_scales
,
dw_int32
,
out
)
else
:
if
int8_simulation_fp8_tensorwise
:
out
=
torch
.
empty_like
(
dw_int32
,
device
=
dw_int32
.
device
,
dtype
=
torch
.
float32
)
tensorwise_dequantize_float
(
dy_scales
,
x_scales
,
dw_int32
,
out
)
else
:
else
:
out
=
channelwise_dequantize_transA_float
(
dy_scales
,
x_scales
,
dw_int32
)
out
=
channelwise_dequantize_transA_float
(
dy_scales
,
x_scales
,
dw_int32
)
return
out
,
None
,
None
,
None
return
out
,
None
,
None
,
None
...
@@ -465,6 +473,120 @@ def general_grouped_gemm(
...
@@ -465,6 +473,120 @@ def general_grouped_gemm(
else
:
else
:
raise
ValueError
(
f
"Unsupported layout
{
layout
}
in int8 simulation fp8"
)
raise
ValueError
(
f
"Unsupported layout
{
layout
}
in int8 simulation fp8"
)
if
int8_simulation_fp8
and
(
isinstance
(
A
,
Float8TensorBase
)
or
isinstance
(
B
,
Float8TensorBase
))
and
int8_simulation_fp8_tensorwise
:
assert
len
(
set
(
m_splits
))
==
1
,
"Int8 simulation groupgemm just surpport token pad as same as batchgemm for now."
assert
not
gelu
,
"GELU not supported with int8 simulation groupgemm."
assert
not
use_bias
,
"Bias not supported with int8 simulation groupgemm."
assert
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float32
,
"Out_dtype must be bfloat16 or float32 for int8 simulation"
if
layout
==
"TN"
:
assert
out_dtype
is
torch
.
bfloat16
qx_data_list
,
scales_x_list
=
[
b
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
b
.
_fp8_dtype
])
for
b
in
B
],
[
b
.
_scale_inv
for
b
in
B
]
w_data_list
,
scales_w_list
=
[
a
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
a
.
_fp8_dtype
])
for
a
in
A
],
[
a
.
_scale_inv
for
a
in
A
]
num_gemms
=
len
(
A
)
qx_data
=
torch
.
stack
(
qx_data_list
).
contiguous
()
w_data
=
torch
.
stack
(
w_data_list
).
contiguous
()
scales_x
=
torch
.
stack
(
scales_x_list
).
contiguous
()
scales_w
=
torch
.
stack
(
scales_w_list
).
contiguous
()
out
[
0
]
=
tex
.
tensorwise_int8_batchgemm
(
w_data
.
view
(
-
1
,
w_data
.
size
(
-
1
)),
transa
,
qx_data
.
view
(
-
1
,
qx_data
.
size
(
-
1
)),
transb
,
scales_w
,
scales_x
,
out
[
0
],
num_gemms
,
None
,
TE_DType
[
out_dtype
],
bias
,
bias_dtype
,
gelu
,
gelu_input
[
0
],
grad
,
# grad
workspaces
[
0
],
workspaces
[
0
].
shape
[
0
],
accumulate
,
use_split_accumulator
,
)[
0
]
return
out
,
bias
,
gelu_input
if
layout
==
"NN"
:
assert
out_dtype
is
torch
.
bfloat16
qdout_data_list
,
scales_dout_list
=
[
b
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
b
.
_fp8_dtype
])
for
b
in
B
],
[
b
.
_scale_inv
for
b
in
B
]
w_data_list
,
scales_w_list
=
[
a
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
a
.
_fp8_dtype
])
for
a
in
A
],
[
a
.
_scale_inv
for
a
in
A
]
num_gemms
=
len
(
A
)
qdout_data
=
torch
.
stack
(
qdout_data_list
).
contiguous
()
w_data
=
torch
.
stack
(
w_data_list
).
contiguous
()
scales_dout
=
torch
.
stack
(
scales_dout_list
).
contiguous
()
scales_w
=
torch
.
stack
(
scales_w_list
).
contiguous
()
out
[
0
]
=
tex
.
tensorwise_int8_batchgemm
(
w_data
.
view
(
-
1
,
w_data
.
size
(
-
1
)),
transa
,
qdout_data
.
view
(
-
1
,
qdout_data
.
size
(
-
1
)),
transb
,
scales_w
,
scales_dout
,
out
[
0
],
num_gemms
,
None
,
TE_DType
[
out_dtype
],
bias
,
bias_dtype
,
gelu
,
gelu_input
[
0
],
grad
,
# grad
workspaces
[
0
],
workspaces
[
0
].
shape
[
0
],
accumulate
,
use_split_accumulator
,
)[
0
]
return
out
,
bias
,
gelu_input
elif
layout
==
"NT"
:
assert
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float32
qdout_data_list
,
scales_dout_list
=
[
b
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
b
.
_fp8_dtype
])
for
b
in
B
],
[
b
.
_scale_inv
for
b
in
B
]
qx_data_list
,
scales_x_list
=
[
a
.
_data
.
view
(
dtype
=
TE_DType_To_Torch
[
a
.
_fp8_dtype
])
for
a
in
A
],
[
a
.
_scale_inv
for
a
in
A
]
num_gemms
=
len
(
A
)
qdout_data
=
torch
.
stack
(
qdout_data_list
).
contiguous
()
qx_data
=
torch
.
stack
(
qx_data_list
).
contiguous
()
scales_dout
=
torch
.
stack
(
scales_dout_list
).
contiguous
()
scales_x
=
torch
.
stack
(
scales_x_list
).
contiguous
()
dw
=
torch
.
stack
(
out
).
contiguous
()
dw
=
tex
.
tensorwise_int8_batchgemm
(
qx_data
.
view
(
-
1
,
qx_data
.
size
(
-
1
)),
transa
,
qdout_data
.
view
(
-
1
,
qdout_data
.
size
(
-
1
)),
transb
,
scales_x
,
scales_dout
,
dw
.
view
(
-
1
,
dw
.
size
(
-
1
)),
num_gemms
,
None
,
TE_DType
[
out_dtype
],
bias
,
bias_dtype
,
gelu
,
gelu_input
[
0
],
grad
,
# grad
workspaces
[
0
],
workspaces
[
0
].
shape
[
0
],
accumulate
,
use_split_accumulator
,
)
for
i
in
range
(
num_gemms
):
out
[
i
].
copy_
(
dw
[
i
])
return
out
,
bias
,
gelu_input
if
int8_simulation_fp8
and
(
isinstance
(
A
[
0
],
Float8TensorBase
)
or
isinstance
(
B
[
0
],
Float8TensorBase
)):
if
int8_simulation_fp8
and
(
isinstance
(
A
[
0
],
Float8TensorBase
)
or
isinstance
(
B
[
0
],
Float8TensorBase
)):
assert
len
(
set
(
m_splits
))
==
1
,
"Int8 simulation groupgemm just surpport token pad as same as batchgemm for now."
assert
len
(
set
(
m_splits
))
==
1
,
"Int8 simulation groupgemm just surpport token pad as same as batchgemm for now."
assert
not
gelu
,
"GELU not supported with int8 simulation groupgemm."
assert
not
gelu
,
"GELU not supported with int8 simulation groupgemm."
...
...
transformer_engine/pytorch/csrc/extensions.h
View file @
7a923605
...
@@ -88,14 +88,6 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
...
@@ -88,14 +88,6 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
std
::
optional
<
CommOverlapType
>
comm_type
=
std
::
nullopt
,
std
::
optional
<
CommOverlapType
>
comm_type
=
std
::
nullopt
,
MaybeTensor
extra_output
=
std
::
nullopt
,
bool
bulk_overlap
=
false
);
MaybeTensor
extra_output
=
std
::
nullopt
,
bool
bulk_overlap
=
false
);
std
::
vector
<
py
::
object
>
generic_batchgemm
(
py
::
handle
A
,
bool
transa
,
py
::
handle
B
,
bool
transb
,
py
::
object
D
,
int
batch_count
,
py
::
handle
quantizer
,
std
::
optional
<
DType
>
out_dtype
,
MaybeTensor
bias
,
DType
bias_type
,
bool
gelu
,
MaybeTensor
gelu_in
,
bool
grad
,
at
::
Tensor
workspace
,
size_t
workspaceSize
,
bool
accumulate
,
bool
use_split_accumulator
,
CommOverlapCore
*
comm_overlap
=
nullptr
,
std
::
optional
<
CommOverlapType
>
comm_type
=
std
::
nullopt
,
MaybeTensor
extra_output
=
std
::
nullopt
,
bool
bulk_overlap
=
false
);
void
te_atomic_gemm
(
at
::
Tensor
A
,
at
::
Tensor
A_scale_inverse
,
DType
A_type
,
void
te_atomic_gemm
(
at
::
Tensor
A
,
at
::
Tensor
A_scale_inverse
,
DType
A_type
,
std
::
vector
<
int64_t
>
A_scaling_mode
,
bool
transa
,
at
::
Tensor
B
,
std
::
vector
<
int64_t
>
A_scaling_mode
,
bool
transa
,
at
::
Tensor
B
,
at
::
Tensor
B_scale_inverse
,
DType
B_type
,
std
::
vector
<
int64_t
>
B_scaling_mode
,
at
::
Tensor
B_scale_inverse
,
DType
B_type
,
std
::
vector
<
int64_t
>
B_scaling_mode
,
...
@@ -113,6 +105,22 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
...
@@ -113,6 +105,22 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
size_t
workspaceSize
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
);
size_t
workspaceSize
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
);
#ifdef __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
std
::
vector
<
py
::
object
>
generic_batchgemm
(
py
::
handle
A
,
bool
transa
,
py
::
handle
B
,
bool
transb
,
py
::
object
D
,
int
batch_count
,
py
::
handle
quantizer
,
std
::
optional
<
DType
>
out_dtype
,
MaybeTensor
bias
,
DType
bias_type
,
bool
gelu
,
MaybeTensor
gelu_in
,
bool
grad
,
at
::
Tensor
workspace
,
size_t
workspaceSize
,
bool
accumulate
,
bool
use_split_accumulator
,
CommOverlapCore
*
comm_overlap
=
nullptr
,
std
::
optional
<
CommOverlapType
>
comm_type
=
std
::
nullopt
,
MaybeTensor
extra_output
=
std
::
nullopt
,
bool
bulk_overlap
=
false
);
std
::
vector
<
py
::
object
>
tensorwise_int8_batchgemm
(
py
::
handle
A
,
bool
transa
,
py
::
handle
B
,
bool
transb
,
py
::
handle
A_scales
,
py
::
handle
B_scales
,
py
::
object
D
,
int
batch_count
,
py
::
handle
quantizer
,
std
::
optional
<
DType
>
out_dtype
,
MaybeTensor
bias
,
DType
bias_type
,
bool
gelu
,
MaybeTensor
gelu_in
,
bool
grad
,
at
::
Tensor
workspace
,
size_t
workspaceSize
,
bool
accumulate
,
bool
use_split_accumulator
,
CommOverlapCore
*
comm_overlap
=
nullptr
,
std
::
optional
<
CommOverlapType
>
comm_type
=
std
::
nullopt
,
MaybeTensor
extra_output
=
std
::
nullopt
,
bool
bulk_overlap
=
false
);
void
te_batchgemm
(
std
::
vector
<
at
::
Tensor
>
A
,
at
::
Tensor
A_scale_inverse
,
int
A_offset
,
void
te_batchgemm
(
std
::
vector
<
at
::
Tensor
>
A
,
at
::
Tensor
A_scale_inverse
,
int
A_offset
,
transformer_engine
::
DType
A_type
,
bool
transa
,
std
::
vector
<
at
::
Tensor
>
B
,
transformer_engine
::
DType
A_type
,
bool
transa
,
std
::
vector
<
at
::
Tensor
>
B
,
at
::
Tensor
B_scale_inverse
,
int
B_offset
,
transformer_engine
::
DType
B_type
,
at
::
Tensor
B_scale_inverse
,
int
B_offset
,
transformer_engine
::
DType
B_type
,
...
...
transformer_engine/pytorch/csrc/extensions/gemm.cpp
View file @
7a923605
...
@@ -271,138 +271,6 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
...
@@ -271,138 +271,6 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
return
out
;
return
out
;
}
}
/// Batched GEMM entry point exposed to Python: D = op(A) x op(B) (+ bias, + GELU).
/// Thin pybind-facing wrapper that converts Python/ATen tensors into
/// TransformerEngine TensorWrappers and dispatches to nvte_cublas_batchgemm_v2.
/// Returns {D, bias_grad-or-None, pre_gelu_out-or-None, extra_output-or-None}.
/// NOTE(review): comm_overlap is accepted for signature parity with gemm() but
/// is rejected at runtime below.
std::vector<py::object> generic_batchgemm(py::handle A, bool transa, py::handle B, bool transb,
                                          py::object D, int batch_count, py::handle quantizer,
                                          std::optional<DType> out_dtype, MaybeTensor bias,
                                          DType bias_type, bool gelu, MaybeTensor gelu_in,
                                          bool grad, at::Tensor workspace, size_t workspaceSize,
                                          bool accumulate, bool use_split_accumulator,
                                          CommOverlapCore *comm_overlap,
                                          std::optional<CommOverlapType> comm_type,
                                          MaybeTensor extra_output, bool bulk_overlap) {
  // Input tensors
  NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
  NVTE_CHECK(!B.is_none(), "Tensor B has not been provided");
  auto none = py::none();
  // No quantizer is applied to the inputs here (quantizer is only used for D).
  TensorWrapper A_tensor = makeTransformerEngineTensor(A, none);
  TensorWrapper B_tensor = makeTransformerEngineTensor(B, none);

  // Low-precision inputs change which dtype the pre-GELU buffer uses (see below).
  const bool low_precision =
      detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype());

  // Check tensor dimensions
  const auto &A_shape = A_tensor.shape();
  const auto &B_shape = B_tensor.shape();
  const auto &D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb);
  NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension");
  NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension");

  // Output tensor: unlike the generic gemm() path, D must be pre-allocated by
  // the caller; this wrapper never allocates the output itself.
  TensorWrapper D_tensor;
  if (D.is_none()) {
    NVTE_ERROR("generic batchgemm D must be not None.");
  } else {
    D_tensor = makeTransformerEngineTensor(D, quantizer);
    if (out_dtype) {
      NVTE_CHECK(*out_dtype == D_tensor.dtype(), "GEMM output has invalid dtype (expected ",
                 static_cast<int>(*out_dtype), ", found ", static_cast<int>(D_tensor.dtype()),
                 ")");
    }
  }

  // Bias tensor: in grad mode a fresh dgrad-bias buffer is allocated (filled by
  // the GEMM epilogue); in forward mode the caller-provided bias is used as-is.
  TensorWrapper bias_tensor;
  MaybeTensor bias_grad = std::nullopt;
  if (bias.has_value()) {
    if (grad) {
      auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA);
      // Bias grad length = last dim of B — TODO confirm this matches the v2
      // kernel's epilogue layout for both transb settings.
      bias_grad = at::empty({static_cast<int64_t>(B_shape.data[B_shape.ndim - 1])}, opts);
      bias_tensor = makeTransformerEngineTensor(*bias_grad);
    } else {
      if (!bias->is_contiguous()) {
        bias = bias->contiguous();
      }
      bias_tensor = makeTransformerEngineTensor(*bias);
    }
  }

  // Activation input tensor (pre-GELU buffer). Forward: allocate a D-shaped
  // buffer the epilogue writes into; backward: reuse the saved gelu_in.
  MaybeTensor pre_gelu_out = std::nullopt;
  DType gelu_type = low_precision ? bias_type : D_tensor.dtype();
  if (gelu) {
    if (!grad) {
      auto dtype = GetATenDType(gelu_type);
      auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
      std::vector<int64_t> torch_shape;
      for (auto v : D_shape) {
        torch_shape.push_back(v);
      }
      pre_gelu_out = at::empty(torch_shape, opts);
    } else {
      if (gelu_in.has_value()) {
        pre_gelu_out = *gelu_in;
      }
    }
  }
  const auto gelu_shape = gelu ? D_shape : std::vector<size_t>{0};
  auto te_pre_gelu_out =
      makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type);

  // Workspace
  auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
                                                  std::vector<size_t>{workspaceSize}, DType::kByte);

  // Set an external SM Margin to all the GEMMs.
  // This comes in handy when DP is overlapped with GEMMs
  const int device_id = at::cuda::current_device();
  const int sm_count = transformer_engine::cuda::sm_count(device_id);
  int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

  // Keep the swizzled scaling factor tensors alive during the GEMM.
  std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
  auto main_stream = at::cuda::getCurrentCUDAStream();
  if (A_tensor.numel() != 0 && B_tensor.numel() != 0) {
    // Optionally swizzle the scaling factors
    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa)));
    swizzled_scale_inverses_list.emplace_back(
        std::move(swizzle_scaling_factors(B_tensor, !transb)));

    if (comm_overlap) {
      NVTE_ERROR("generic batchgemm not surpport comm_overlap.");
    } else {
      // Launch GEMM
      NVTE_SCOPED_GIL_RELEASE({
        nvte_cublas_batchgemm_v2(A_tensor.data(), B_tensor.data(), D_tensor.data(),
                                 bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad,
                                 te_workspace.data(), accumulate, use_split_accumulator,
                                 num_math_sms, batch_count, main_stream);
      });
    }
  } else {
    // Degenerate (empty) input: emulate the GEMM's effect on the outputs.
    if (D_tensor.numel() != 0 && !accumulate) {
      D_tensor.zero_(main_stream);
    }
    if (bias.has_value()) {
      if (bias->numel() != 0 && grad) {
        bias_grad->zero_();
      }
    }
  }

  // Pack outputs
  std::vector<py::object> out;
  out.emplace_back(std::move(D));
  out.emplace_back(py::cast(bias_grad));
  if (gelu && !grad) {
    out.emplace_back(py::cast(*pre_gelu_out));
  } else {
    out.emplace_back(py::none());
  }
  if (extra_output.has_value()) {
    out.emplace_back(py::cast(extra_output));
  } else {
    out.emplace_back(py::none());
  }
  return out;
}
void
te_atomic_gemm
(
at
::
Tensor
A
,
at
::
Tensor
A_scale_inverse
,
DType
A_type
,
void
te_atomic_gemm
(
at
::
Tensor
A
,
at
::
Tensor
A_scale_inverse
,
DType
A_type
,
std
::
vector
<
int64_t
>
A_scaling_mode
,
bool
transa
,
at
::
Tensor
B
,
std
::
vector
<
int64_t
>
A_scaling_mode
,
bool
transa
,
at
::
Tensor
B
,
at
::
Tensor
B_scale_inverse
,
DType
B_type
,
std
::
vector
<
int64_t
>
B_scaling_mode
,
at
::
Tensor
B_scale_inverse
,
DType
B_type
,
std
::
vector
<
int64_t
>
B_scaling_mode
,
...
@@ -586,6 +454,274 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
...
@@ -586,6 +454,274 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
}
}
#ifdef USE_ROCM
#ifdef USE_ROCM
/// Batched GEMM entry point exposed to Python: D = op(A) x op(B) (+ bias, + GELU).
/// ROCm build of the pybind-facing wrapper: converts Python/ATen tensors into
/// TransformerEngine TensorWrappers and dispatches to nvte_cublas_batchgemm_v2.
/// Returns {D, bias_grad-or-None, pre_gelu_out-or-None, extra_output-or-None}.
/// NOTE(review): comm_overlap is accepted for signature parity with gemm() but
/// is rejected at runtime below.
std::vector<py::object> generic_batchgemm(py::handle A, bool transa, py::handle B, bool transb,
                                          py::object D, int batch_count, py::handle quantizer,
                                          std::optional<DType> out_dtype, MaybeTensor bias,
                                          DType bias_type, bool gelu, MaybeTensor gelu_in,
                                          bool grad, at::Tensor workspace, size_t workspaceSize,
                                          bool accumulate, bool use_split_accumulator,
                                          CommOverlapCore *comm_overlap,
                                          std::optional<CommOverlapType> comm_type,
                                          MaybeTensor extra_output, bool bulk_overlap) {
  // Input tensors
  NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
  NVTE_CHECK(!B.is_none(), "Tensor B has not been provided");
  auto none = py::none();
  // No quantizer is applied to the inputs here (quantizer is only used for D).
  TensorWrapper A_tensor = makeTransformerEngineTensor(A, none);
  TensorWrapper B_tensor = makeTransformerEngineTensor(B, none);

  // Low-precision inputs change which dtype the pre-GELU buffer uses (see below).
  const bool low_precision =
      detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype());

  // Check tensor dimensions
  const auto &A_shape = A_tensor.shape();
  const auto &B_shape = B_tensor.shape();
  const auto &D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb);
  NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension");
  NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension");

  // Output tensor: unlike the generic gemm() path, D must be pre-allocated by
  // the caller; this wrapper never allocates the output itself.
  TensorWrapper D_tensor;
  if (D.is_none()) {
    NVTE_ERROR("generic batchgemm D must not be None.");
  } else {
    D_tensor = makeTransformerEngineTensor(D, quantizer);
    if (out_dtype) {
      NVTE_CHECK(*out_dtype == D_tensor.dtype(), "GEMM output has invalid dtype (expected ",
                 static_cast<int>(*out_dtype), ", found ", static_cast<int>(D_tensor.dtype()),
                 ")");
    }
  }

  // Bias tensor: in grad mode a fresh dgrad-bias buffer is allocated (filled by
  // the GEMM epilogue); in forward mode the caller-provided bias is used as-is.
  TensorWrapper bias_tensor;
  MaybeTensor bias_grad = std::nullopt;
  if (bias.has_value()) {
    if (grad) {
      auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA);
      // Bias grad length = last dim of B — TODO confirm this matches the v2
      // kernel's epilogue layout for both transb settings.
      bias_grad = at::empty({static_cast<int64_t>(B_shape.data[B_shape.ndim - 1])}, opts);
      bias_tensor = makeTransformerEngineTensor(*bias_grad);
    } else {
      if (!bias->is_contiguous()) {
        bias = bias->contiguous();
      }
      bias_tensor = makeTransformerEngineTensor(*bias);
    }
  }

  // Activation input tensor (pre-GELU buffer). Forward: allocate a D-shaped
  // buffer the epilogue writes into; backward: reuse the saved gelu_in.
  MaybeTensor pre_gelu_out = std::nullopt;
  DType gelu_type = low_precision ? bias_type : D_tensor.dtype();
  if (gelu) {
    if (!grad) {
      auto dtype = GetATenDType(gelu_type);
      auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
      std::vector<int64_t> torch_shape;
      for (auto v : D_shape) {
        torch_shape.push_back(v);
      }
      pre_gelu_out = at::empty(torch_shape, opts);
    } else {
      if (gelu_in.has_value()) {
        pre_gelu_out = *gelu_in;
      }
    }
  }
  const auto gelu_shape = gelu ? D_shape : std::vector<size_t>{0};
  auto te_pre_gelu_out =
      makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type);

  // Workspace
  auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
                                                  std::vector<size_t>{workspaceSize}, DType::kByte);

  // Set an external SM Margin to all the GEMMs.
  // This comes in handy when DP is overlapped with GEMMs
  const int device_id = at::cuda::current_device();
  const int sm_count = transformer_engine::cuda::sm_count(device_id);
  int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

  // Keep the swizzled scaling factor tensors alive during the GEMM.
  std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
  auto main_stream = at::cuda::getCurrentCUDAStream();
  if (A_tensor.numel() != 0 && B_tensor.numel() != 0) {
    // Optionally swizzle the scaling factors
    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa)));
    swizzled_scale_inverses_list.emplace_back(
        std::move(swizzle_scaling_factors(B_tensor, !transb)));

    if (comm_overlap) {
      // Fixed wording/typo of the user-facing diagnostic ("not surpport").
      NVTE_ERROR("generic batchgemm does not support comm_overlap.");
    } else {
      // Launch GEMM
      NVTE_SCOPED_GIL_RELEASE({
        nvte_cublas_batchgemm_v2(A_tensor.data(), B_tensor.data(), D_tensor.data(),
                                 bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad,
                                 te_workspace.data(), accumulate, use_split_accumulator,
                                 num_math_sms, batch_count, main_stream);
      });
    }
  } else {
    // Degenerate (empty) input: emulate the GEMM's effect on the outputs.
    if (D_tensor.numel() != 0 && !accumulate) {
      D_tensor.zero_(main_stream);
    }
    if (bias.has_value()) {
      if (bias->numel() != 0 && grad) {
        bias_grad->zero_();
      }
    }
  }

  // Pack outputs
  std::vector<py::object> out;
  out.emplace_back(std::move(D));
  out.emplace_back(py::cast(bias_grad));
  if (gelu && !grad) {
    out.emplace_back(py::cast(*pre_gelu_out));
  } else {
    out.emplace_back(py::none());
  }
  if (extra_output.has_value()) {
    out.emplace_back(py::cast(extra_output));
  } else {
    out.emplace_back(py::none());
  }
  return out;
}
/// Tensor-wise INT8 batched GEMM exposed to Python:
/// D = dequant(op(A) x op(B); A_scales, B_scales) (+ bias, + GELU).
/// Same plumbing as generic_batchgemm, but additionally wraps per-tensor scale
/// tensors for A and B and dispatches to nvte_cublas_batchgemm_v3, which takes
/// the scales so the INT32 accumulator can be dequantized tensor-wise.
/// Returns {D, bias_grad-or-None, pre_gelu_out-or-None, extra_output-or-None}.
/// NOTE(review): comm_overlap is accepted for signature parity but rejected at
/// runtime below.
std::vector<py::object> tensorwise_int8_batchgemm(
    py::handle A, bool transa, py::handle B, bool transb, py::handle A_scales,
    py::handle B_scales, py::object D, int batch_count, py::handle quantizer,
    std::optional<DType> out_dtype, MaybeTensor bias, DType bias_type, bool gelu,
    MaybeTensor gelu_in, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate,
    bool use_split_accumulator, CommOverlapCore *comm_overlap,
    std::optional<CommOverlapType> comm_type, MaybeTensor extra_output, bool bulk_overlap) {
  // Input tensors. The scale checks previously reported "Tensor A"/"Tensor B"
  // (copy-paste); they now name the tensor that is actually missing.
  NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
  NVTE_CHECK(!B.is_none(), "Tensor B has not been provided");
  NVTE_CHECK(!A_scales.is_none(), "Tensor A_scales has not been provided");
  NVTE_CHECK(!B_scales.is_none(), "Tensor B_scales has not been provided");
  auto none = py::none();
  // No quantizer is applied to the inputs here (quantizer is only used for D).
  TensorWrapper A_tensor = makeTransformerEngineTensor(A, none);
  TensorWrapper B_tensor = makeTransformerEngineTensor(B, none);
  TensorWrapper A_scales_tensor = makeTransformerEngineTensor(A_scales, none);
  TensorWrapper B_scales_tensor = makeTransformerEngineTensor(B_scales, none);

  // Low-precision inputs change which dtype the pre-GELU buffer uses (see below).
  const bool low_precision =
      detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype());

  // Check tensor dimensions
  const auto &A_shape = A_tensor.shape();
  const auto &B_shape = B_tensor.shape();
  const auto &D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb);
  NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension");
  NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension");

  // Output tensor: D must be pre-allocated by the caller; this wrapper never
  // allocates the output itself.
  TensorWrapper D_tensor;
  if (D.is_none()) {
    NVTE_ERROR("tensorwise int8 batchgemm D must not be None.");
  } else {
    D_tensor = makeTransformerEngineTensor(D, quantizer);
    if (out_dtype) {
      NVTE_CHECK(*out_dtype == D_tensor.dtype(), "GEMM output has invalid dtype (expected ",
                 static_cast<int>(*out_dtype), ", found ", static_cast<int>(D_tensor.dtype()),
                 ")");
    }
  }

  // Bias tensor: in grad mode a fresh dgrad-bias buffer is allocated (filled by
  // the GEMM epilogue); in forward mode the caller-provided bias is used as-is.
  TensorWrapper bias_tensor;
  MaybeTensor bias_grad = std::nullopt;
  if (bias.has_value()) {
    if (grad) {
      auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA);
      // Bias grad length = last dim of B — TODO confirm this matches the v3
      // kernel's epilogue layout for both transb settings.
      bias_grad = at::empty({static_cast<int64_t>(B_shape.data[B_shape.ndim - 1])}, opts);
      bias_tensor = makeTransformerEngineTensor(*bias_grad);
    } else {
      if (!bias->is_contiguous()) {
        bias = bias->contiguous();
      }
      bias_tensor = makeTransformerEngineTensor(*bias);
    }
  }

  // Activation input tensor (pre-GELU buffer). Forward: allocate a D-shaped
  // buffer the epilogue writes into; backward: reuse the saved gelu_in.
  MaybeTensor pre_gelu_out = std::nullopt;
  DType gelu_type = low_precision ? bias_type : D_tensor.dtype();
  if (gelu) {
    if (!grad) {
      auto dtype = GetATenDType(gelu_type);
      auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
      std::vector<int64_t> torch_shape;
      for (auto v : D_shape) {
        torch_shape.push_back(v);
      }
      pre_gelu_out = at::empty(torch_shape, opts);
    } else {
      if (gelu_in.has_value()) {
        pre_gelu_out = *gelu_in;
      }
    }
  }
  const auto gelu_shape = gelu ? D_shape : std::vector<size_t>{0};
  auto te_pre_gelu_out =
      makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type);

  // Workspace
  auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
                                                  std::vector<size_t>{workspaceSize}, DType::kByte);

  // Set an external SM Margin to all the GEMMs.
  // This comes in handy when DP is overlapped with GEMMs
  const int device_id = at::cuda::current_device();
  const int sm_count = transformer_engine::cuda::sm_count(device_id);
  int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

  // Keep the swizzled scaling factor tensors alive during the GEMM.
  std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
  auto main_stream = at::cuda::getCurrentCUDAStream();
  if (A_tensor.numel() != 0 && B_tensor.numel() != 0) {
    // Optionally swizzle the scaling factors
    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa)));
    swizzled_scale_inverses_list.emplace_back(
        std::move(swizzle_scaling_factors(B_tensor, !transb)));

    if (comm_overlap) {
      // Fixed the copy-pasted diagnostic (it said "generic batchgemm" and had a
      // "surpport" typo).
      NVTE_ERROR("tensorwise int8 batchgemm does not support comm_overlap.");
    } else {
      // Launch GEMM: v3 additionally takes the per-tensor scale buffers.
      NVTE_SCOPED_GIL_RELEASE({
        nvte_cublas_batchgemm_v3(A_tensor.data(), B_tensor.data(), A_scales_tensor.data(),
                                 B_scales_tensor.data(), D_tensor.data(), bias_tensor.data(),
                                 te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
                                 accumulate, use_split_accumulator, num_math_sms, batch_count,
                                 main_stream);
      });
    }
  } else {
    // Degenerate (empty) input: emulate the GEMM's effect on the outputs.
    if (D_tensor.numel() != 0 && !accumulate) {
      D_tensor.zero_(main_stream);
    }
    if (bias.has_value()) {
      if (bias->numel() != 0 && grad) {
        bias_grad->zero_();
      }
    }
  }

  // Pack outputs
  std::vector<py::object> out;
  out.emplace_back(std::move(D));
  out.emplace_back(py::cast(bias_grad));
  if (gelu && !grad) {
    out.emplace_back(py::cast(*pre_gelu_out));
  } else {
    out.emplace_back(py::none());
  }
  if (extra_output.has_value()) {
    out.emplace_back(py::cast(extra_output));
  } else {
    out.emplace_back(py::none());
  }
  return out;
}
void
te_batchgemm
(
std
::
vector
<
at
::
Tensor
>
A
,
at
::
Tensor
A_scale_inverse
,
int
A_offset
,
void
te_batchgemm
(
std
::
vector
<
at
::
Tensor
>
A
,
at
::
Tensor
A_scale_inverse
,
int
A_offset
,
transformer_engine
::
DType
A_type
,
bool
transa
,
std
::
vector
<
at
::
Tensor
>
B
,
transformer_engine
::
DType
A_type
,
bool
transa
,
std
::
vector
<
at
::
Tensor
>
B
,
at
::
Tensor
B_scale_inverse
,
int
B_offset
,
transformer_engine
::
DType
B_type
,
at
::
Tensor
B_scale_inverse
,
int
B_offset
,
transformer_engine
::
DType
B_type
,
...
...
transformer_engine/pytorch/csrc/extensions/pybind.cpp
View file @
7a923605
...
@@ -110,13 +110,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -110,13 +110,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py
::
arg
(
"workspace_size"
),
py
::
arg
(
"accumulate"
),
py
::
arg
(
"use_split_accumulator"
),
py
::
arg
(
"workspace_size"
),
py
::
arg
(
"accumulate"
),
py
::
arg
(
"use_split_accumulator"
),
py
::
arg
(
"comm_overlap"
)
=
nullptr
,
py
::
arg
(
"comm_type"
)
=
std
::
nullopt
,
py
::
arg
(
"comm_overlap"
)
=
nullptr
,
py
::
arg
(
"comm_type"
)
=
std
::
nullopt
,
py
::
arg
(
"extra_output"
)
=
std
::
nullopt
,
py
::
arg
(
"bulk_overlap"
)
=
false
);
py
::
arg
(
"extra_output"
)
=
std
::
nullopt
,
py
::
arg
(
"bulk_overlap"
)
=
false
);
m
.
def
(
"generic_batchgemm"
,
transformer_engine
::
pytorch
::
generic_batchgemm
,
"Compute Batched GEMM (matrix-matrix multiply)"
,
py
::
arg
(
"A"
),
py
::
arg
(
"transA"
),
py
::
arg
(
"B"
),
py
::
arg
(
"transB"
),
py
::
arg
(
"D"
),
py
::
arg
(
"batchcount"
),
py
::
arg
(
"quantizer"
),
py
::
arg
(
"output_dtype"
),
py
::
arg
(
"bias"
),
py
::
arg
(
"bias_type"
),
py
::
arg
(
"gelu"
),
py
::
arg
(
"gelu_in"
),
py
::
arg
(
"grad"
),
py
::
arg
(
"workspace"
),
py
::
arg
(
"workspace_size"
),
py
::
arg
(
"accumulate"
),
py
::
arg
(
"use_split_accumulator"
),
py
::
arg
(
"comm_overlap"
)
=
nullptr
,
py
::
arg
(
"comm_type"
)
=
std
::
nullopt
,
py
::
arg
(
"extra_output"
)
=
std
::
nullopt
,
py
::
arg
(
"bulk_overlap"
)
=
false
);
m
.
def
(
"gelu"
,
transformer_engine
::
pytorch
::
gelu
,
"GeLU activation"
,
py
::
arg
(
"input"
),
m
.
def
(
"gelu"
,
transformer_engine
::
pytorch
::
gelu
,
"GeLU activation"
,
py
::
arg
(
"input"
),
py
::
arg
(
"quantizer"
));
py
::
arg
(
"quantizer"
));
m
.
def
(
"relu"
,
transformer_engine
::
pytorch
::
relu
,
"ReLU activation"
,
py
::
arg
(
"input"
),
m
.
def
(
"relu"
,
transformer_engine
::
pytorch
::
relu
,
"ReLU activation"
,
py
::
arg
(
"input"
),
...
@@ -213,6 +206,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -213,6 +206,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m
.
def
(
"te_general_grouped_gemm"
,
&
transformer_engine
::
pytorch
::
te_general_grouped_gemm
,
m
.
def
(
"te_general_grouped_gemm"
,
&
transformer_engine
::
pytorch
::
te_general_grouped_gemm
,
"Grouped GEMM"
);
"Grouped GEMM"
);
#ifdef USE_ROCM
#ifdef USE_ROCM
m
.
def
(
"generic_batchgemm"
,
transformer_engine
::
pytorch
::
generic_batchgemm
,
"Compute Batched GEMM (matrix-matrix multiply)"
,
py
::
arg
(
"A"
),
py
::
arg
(
"transA"
),
py
::
arg
(
"B"
),
py
::
arg
(
"transB"
),
py
::
arg
(
"D"
),
py
::
arg
(
"batchcount"
),
py
::
arg
(
"quantizer"
),
py
::
arg
(
"output_dtype"
),
py
::
arg
(
"bias"
),
py
::
arg
(
"bias_type"
),
py
::
arg
(
"gelu"
),
py
::
arg
(
"gelu_in"
),
py
::
arg
(
"grad"
),
py
::
arg
(
"workspace"
),
py
::
arg
(
"workspace_size"
),
py
::
arg
(
"accumulate"
),
py
::
arg
(
"use_split_accumulator"
),
py
::
arg
(
"comm_overlap"
)
=
nullptr
,
py
::
arg
(
"comm_type"
)
=
std
::
nullopt
,
py
::
arg
(
"extra_output"
)
=
std
::
nullopt
,
py
::
arg
(
"bulk_overlap"
)
=
false
);
m
.
def
(
"tensorwise_int8_batchgemm"
,
transformer_engine
::
pytorch
::
tensorwise_int8_batchgemm
,
"Compute Tensorwise Int8 Batched GEMM (matrix-matrix multiply)"
,
py
::
arg
(
"A"
),
py
::
arg
(
"transA"
),
py
::
arg
(
"B"
),
py
::
arg
(
"transB"
),
py
::
arg
(
"A_scales"
),
py
::
arg
(
"B_scales"
),
py
::
arg
(
"D"
),
py
::
arg
(
"batchcount"
),
py
::
arg
(
"quantizer"
),
py
::
arg
(
"output_dtype"
),
py
::
arg
(
"bias"
),
py
::
arg
(
"bias_type"
),
py
::
arg
(
"gelu"
),
py
::
arg
(
"gelu_in"
),
py
::
arg
(
"grad"
),
py
::
arg
(
"workspace"
),
py
::
arg
(
"workspace_size"
),
py
::
arg
(
"accumulate"
),
py
::
arg
(
"use_split_accumulator"
),
py
::
arg
(
"comm_overlap"
)
=
nullptr
,
py
::
arg
(
"comm_type"
)
=
std
::
nullopt
,
py
::
arg
(
"extra_output"
)
=
std
::
nullopt
,
py
::
arg
(
"bulk_overlap"
)
=
false
);
m
.
def
(
"te_batchgemm_ts"
,
&
transformer_engine
::
pytorch
::
te_batchgemm_ts
,
"Batched GEMM"
);
/// rocblas
m
.
def
(
"te_batchgemm_ts"
,
&
transformer_engine
::
pytorch
::
te_batchgemm_ts
,
"Batched GEMM"
);
/// rocblas
#endif
#endif
m
.
def
(
"fp8_transpose"
,
&
transformer_engine
::
pytorch
::
fp8_transpose
,
"Transpose with FP8 I/O"
,
m
.
def
(
"fp8_transpose"
,
&
transformer_engine
::
pytorch
::
fp8_transpose
,
"Transpose with FP8 I/O"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment