OpenDAS / TransformerEngine

Commit 4a013bd5
authored Aug 06, 2025 by yuguo

[DCU] fix channelwise train oom bug

parent ddfbdaf4
Showing 2 changed files with 22 additions and 18 deletions

transformer_engine/common/gemm/rocm_gemm.cu         +15 -15
transformer_engine/pytorch/cpp_extensions/gemm.py    +7 -3
transformer_engine/common/gemm/rocm_gemm.cu
@@ -1479,8 +1479,8 @@ private:
 };
 // Define a static userArgs manager
-static userArgsManager UAManager;
-static d_userArgsManager d_UAManager;
+// static userArgsManager UAManager;
+// static d_userArgsManager d_UAManager;
 void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB, std::vector<Tensor*>& outputD,
                           std::vector<int64_t>& m, std::vector<int64_t>& n, std::vector<int64_t>& k, std::vector<int64_t>& b,
                           hipblasOperation_t transa, hipblasOperation_t transb,
@@ -1489,10 +1489,10 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
   // Check compute_stream_offset valid.
   NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
-  int device_id;
-  hipGetDevice(&device_id);
-  hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
-  hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
+  // int device_id;
+  // hipGetDevice(&device_id);
+  // hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
+  // hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
   // hipblaslt_ext::UserArguments* userArgs;
   // NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
@@ -1566,20 +1566,20 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
   }
   // Get the default values from the grouepdgemm object
-  groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
+  // groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
   // Copy them to device memory
   // hipblaslt_ext::UserArguments* d_userArgs;
   // NVTE_CHECK_CUDA(hipMallocAsync(&d_userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), stream));
-  NVTE_CHECK_CUDA(hipMemcpyAsync(d_userArgs,
-                                 userArgs,
-                                 m.size() * sizeof(hipblaslt_ext::UserArguments),
-                                 hipMemcpyHostToDevice, stream));
+  // NVTE_CHECK_CUDA(hipMemcpyAsync(d_userArgs,
+  //                                userArgs,
+  //                                m.size() * sizeof(hipblaslt_ext::UserArguments),
+  //                                hipMemcpyHostToDevice, stream));
   // Make sure to initialize everytime the algo changes
-  NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
-  NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
-  // NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
-  // NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
+  // NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
+  // NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
+  NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
+  NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
   // NVTE_CHECK_CUDA(hipFreeAsync(d_userArgs, stream));
   // NVTE_CHECK_CUDA(hipFree(userArgs));
transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -190,6 +190,7 @@ def general_gemm(
     if layout == "TN":
         assert out_dtype is torch.bfloat16
+        out_shape = B._data.shape[:-1] + (A._data.shape[0],)
         x_int8, x_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
         w_int8, w_scales = per_token_quant_fp8_to_int8(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
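For reference, per_token_quant_fp8_to_int8 is defined elsewhere in the repo and is not part of this diff. A per-token (row-wise) symmetric INT8 quantization typically yields one scale per token plus an int8 payload; the sketch below illustrates that idea under an assumed amax/127 scaling convention and is not the repo's kernel.

import torch

def per_token_int8_quant_sketch(x: torch.Tensor):
    """Hedged sketch of per-token (row-wise) symmetric INT8 quantization."""
    x = x.float()
    # One scale per token (row), guarded against all-zero rows.
    scales = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_int8 = torch.clamp(torch.round(x / scales), -127, 127).to(torch.int8)
    return x_int8, scales.squeeze(-1)

x = torch.randn(4, 16)               # stand-in for the (dequantized) activations
x_int8, x_scales = per_token_int8_quant_sketch(x)
print(x_int8.shape, x_scales.shape)  # torch.Size([4, 16]) torch.Size([4])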
@@ -212,10 +213,11 @@ def general_gemm(
             use_split_accumulator,
         )[0]
         y = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
-        return y, None, None, None
+        return y.view(out_shape), None, None, None
     elif layout == "NN":
         assert out_dtype is torch.bfloat16
+        dx_shape = B._data.shape[:-1] + (A._data.shape[-1],)
         dy_int8, dy_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
         w_int8, w_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
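The shape-related part of this change: the INT8 GEMM result y is 2-D ([tokens, out_features]), while B._data may carry extra leading dimensions, so the new out_shape records those dimensions and y.view(out_shape) restores them on return. Below is a minimal sketch of the shape bookkeeping, with dummy float tensors and a plain matmul standing in for the INT8 GEMM, and an assumed row/column scaling in place of channelwise_dequantize_transB.

import torch

B_data = torch.randn(2, 8, 64)                      # activations: [batch, seq, k]
A_data = torch.randn(32, 64)                        # weights in TN layout: [n, k]
out_shape = B_data.shape[:-1] + (A_data.shape[0],)  # (2, 8, 32), as in the diff

x = B_data.reshape(-1, B_data.shape[-1])            # flatten to [tokens, k]
y_int32 = x @ A_data.t()                            # stand-in for the INT8 GEMM result
x_scales = torch.rand(x.shape[0])                   # one scale per token
w_scales = torch.rand(A_data.shape[0])              # one scale per output channel

# Assumed channelwise dequantize: rows scaled by token scales, columns by
# output-channel scales (stand-in for channelwise_dequantize_transB).
y = y_int32 * x_scales[:, None] * w_scales[None, :]

print(y.view(out_shape).shape)                      # torch.Size([2, 8, 32])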
@@ -238,7 +240,7 @@ def general_gemm(
             use_split_accumulator,
         )[0]
         dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
-        return dx, None, None, None
+        return dx.view(dx_shape), None, None, None
     elif layout == "NT":
         assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
@@ -475,7 +477,8 @@ def general_grouped_gemm(
         out[0] = out[0].view(num_gemms, seq_len, out[0].size(-1))
         for i in range(num_gemms):
             out[0][i] = channelwise_dequantize_transB(scales_x_list[i], scales_w_list[i], y_int32[i])
-        return out.view(-1, out[0].size(-1)), bias, gelu_input
+        out[0] = out[0].view(-1, out[0].size(-1))
+        return out, bias, gelu_input
     elif layout == "NN":
         assert out_dtype is torch.bfloat16
@@ -522,6 +525,7 @@ def general_grouped_gemm(
         out[0] = out[0].view(num_gemms, seq_len, out[0].size(-1))
         for i in range(num_gemms):
             out[0][i] = channelwise_dequantize(scales_dout_list[i], scales_w_list[i], dx_int32[i])
+        out[0] = out[0].view(-1, out[0].size(-1))
         return out, bias, gelu_input
     elif layout == "NT":
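In the grouped paths, out is indexed as out[0] throughout, i.e. the stacked result lives in the first element of a container, so the old return out.view(...) applied .view to the container rather than to the tensor; the fix reshapes out[0] in place and returns out unchanged. A minimal sketch of that flow with dummy data, using a simplified stand-in for channelwise_dequantize:

import torch

num_gemms, seq_len, hidden = 4, 8, 32
out = [torch.empty(num_gemms * seq_len, hidden)]           # flat grouped-GEMM output
dx_int32 = [torch.randn(seq_len, hidden) for _ in range(num_gemms)]
scales_dout_list = [torch.rand(seq_len, 1) for _ in range(num_gemms)]
scales_w_list = [torch.rand(1, hidden) for _ in range(num_gemms)]

# Dequantize each GEMM's result into its slice of the output buffer.
out[0] = out[0].view(num_gemms, seq_len, out[0].size(-1))
for i in range(num_gemms):
    out[0][i] = dx_int32[i] * scales_dout_list[i] * scales_w_list[i]

# Flatten back to [tokens, hidden] before returning the container (the fixed path).
out[0] = out[0].view(-1, out[0].size(-1))
print(out[0].shape)                                         # torch.Size([32, 32])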