Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
3cdceb87
Commit
3cdceb87
authored
Aug 13, 2025
by
wenjh
Browse files
Delete tmpArgs in groupedgemm
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parent
1f97aebb
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
39 deletions
+3
-39
transformer_engine/common/gemm/rocm_gemm.cu
transformer_engine/common/gemm/rocm_gemm.cu
+3
-39
No files found.
transformer_engine/common/gemm/rocm_gemm.cu
View file @
3cdceb87
...
...
@@ -1478,45 +1478,9 @@ private:
std
::
mutex
mutex_
;
};
class
tmp_userArgsManager
{
public:
tmp_userArgsManager
()
{}
~
tmp_userArgsManager
()
{
// Release all userArgs when the manager is destroyed
for
(
auto
&
device_pair
:
tmp_userArgs_map_
)
{
hipFree
(
device_pair
.
second
);
// Only one userArgs per device
}
}
// Get a userArgs for the given device (creates if necessary)
void
*
get
(
int
device_id
,
size_t
size
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
// Check if the userArgs for this device exists
auto
device_it
=
tmp_userArgs_map_
.
find
(
device_id
);
if
(
device_it
!=
tmp_userArgs_map_
.
end
())
{
return
device_it
->
second
;
}
// Create a new userArgs for this device if it doesn't exist
void
*
tmp_userArgs
;
NVTE_CHECK_CUDA
(
hipHostMalloc
(
&
tmp_userArgs
,
size
));
// Store the userArgs in the map for this device
tmp_userArgs_map_
[
device_id
]
=
tmp_userArgs
;
return
tmp_userArgs
;
}
private:
std
::
unordered_map
<
int
,
void
*>
tmp_userArgs_map_
;
// Map from device_id to hipblasHandle
std
::
mutex
mutex_
;
};
// Define a static userArgs manager
static
userArgsManager
UAManager
;
static
d_userArgsManager
d_UAManager
;
static
tmp_userArgsManager
tmp_UAManager
;
void
hipblaslt_goupedgemm
(
std
::
vector
<
const
Tensor
*>&
inputA
,
std
::
vector
<
const
Tensor
*>&
inputB
,
std
::
vector
<
Tensor
*>&
outputD
,
std
::
vector
<
int64_t
>&
m
,
std
::
vector
<
int64_t
>&
n
,
std
::
vector
<
int64_t
>&
k
,
std
::
vector
<
int64_t
>&
b
,
hipblasOperation_t
transa
,
hipblasOperation_t
transb
,
...
...
@@ -1529,7 +1493,6 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
hipGetDevice
(
&
device_id
);
hipblaslt_ext
::
UserArguments
*
userArgs
=
UAManager
.
get
(
device_id
,
m
.
size
());
hipblaslt_ext
::
UserArguments
*
d_userArgs
=
d_UAManager
.
get
(
device_id
,
m
.
size
());
void
*
tmp_userArgs
=
tmp_UAManager
.
get
(
device_id
,
32768
);
// hipblaslt_ext::UserArguments* userArgs;
// NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
...
...
@@ -1573,8 +1536,7 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
B_type
,
D_type
,
D_type
,
computeType
,
tmp_userArgs
);
computeType
);
std
::
vector
<
hipblaslt_ext
::
GemmEpilogue
>
epilogue
{
hipblaslt_ext
::
...
...
@@ -1605,6 +1567,7 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// Make sure to initialize everytime the algo changes
NVTE_CHECK_HIPBLASLT
(
groupedgemm
.
initialize
(
heuristicResult
[
0
].
algo
,
workspace
));
// Get the default values from the groupedgemm object
groupedgemm
.
getDefaultValueForDeviceUserArguments
(
userArgs
);
// Copy them to device memory
...
...
@@ -1614,6 +1577,7 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
userArgs
,
m
.
size
()
*
sizeof
(
hipblaslt_ext
::
UserArguments
),
hipMemcpyHostToDevice
));
NVTE_CHECK_HIPBLASLT
(
groupedgemm
.
run
(
d_userArgs
,
stream
));
// NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
// NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment