Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
468d9d2d
Commit
468d9d2d
authored
Aug 06, 2025
by
yuguo
Browse files
Merge branch 'develop_v2.5' of
http://10.16.6.30/dcutoolkit/deeplearing/TransformerEngine
parents
eadb9886
ddfbdaf4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
50 additions
and
13 deletions
+50
-13
transformer_engine/common/gemm/rocm_gemm.cu
transformer_engine/common/gemm/rocm_gemm.cu
+50
-13
No files found.
transformer_engine/common/gemm/rocm_gemm.cu
View file @
468d9d2d
...
...
@@ -1443,8 +1443,44 @@ private:
std
::
mutex
mutex_
;
};
class
d_userArgsManager
{
public:
d_userArgsManager
()
{}
~
d_userArgsManager
()
{
// Release all userArgs when the manager is destroyed
for
(
auto
&
device_pair
:
d_userArgs_map_
)
{
hipFree
(
device_pair
.
second
);
// Only one userArgs per device
}
}
// Get a userArgs for the given device (creates if necessary)
hipblaslt_ext
::
UserArguments
*
get
(
int
device_id
,
size_t
size
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
// Check if the userArgs for this device exists
auto
device_it
=
d_userArgs_map_
.
find
(
device_id
);
if
(
device_it
!=
d_userArgs_map_
.
end
())
{
return
device_it
->
second
;
}
// Create a new userArgs for this device if it doesn't exist
hipblaslt_ext
::
UserArguments
*
d_userArgs
;
NVTE_CHECK_CUDA
(
hipMalloc
(
&
d_userArgs
,
size
*
sizeof
(
hipblaslt_ext
::
UserArguments
)));
// Store the userArgs in the map for this device
d_userArgs_map_
[
device_id
]
=
d_userArgs
;
return
d_userArgs
;
}
private:
std
::
unordered_map
<
int
,
hipblaslt_ext
::
UserArguments
*>
d_userArgs_map_
;
// Map from device_id to hipblasHandle
std
::
mutex
mutex_
;
};
// Define a static userArgs manager
// static userArgsManager UAManager;
// Process-wide caches shared by all grouped-GEMM calls:
//   UAManager   — host-side UserArguments arrays (source of the
//                 hipMemcpyHostToDevice copy performed per call below);
//   d_UAManager — the matching device-side buffers, one per device.
// NOTE(review): destruction order at process exit calls hipFree from a static
// destructor — presumed safe with the HIP runtime here; confirm.
static userArgsManager UAManager;
static d_userArgsManager d_UAManager;
void
hipblaslt_goupedgemm
(
std
::
vector
<
const
Tensor
*>&
inputA
,
std
::
vector
<
const
Tensor
*>&
inputB
,
std
::
vector
<
Tensor
*>&
outputD
,
std
::
vector
<
int64_t
>&
m
,
std
::
vector
<
int64_t
>&
n
,
std
::
vector
<
int64_t
>&
k
,
std
::
vector
<
int64_t
>&
b
,
hipblasOperation_t
transa
,
hipblasOperation_t
transb
,
...
...
@@ -1453,9 +1489,10 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// Check that compute_stream_offset is valid.
NVTE_CHECK
(
compute_stream_offset
>=
-
1
&&
compute_stream_offset
<
compute_num_streams
);
// int device_id;
// hipGetDevice(&device_id);
// hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
int
device_id
;
hipGetDevice
(
&
device_id
);
hipblaslt_ext
::
UserArguments
*
userArgs
=
UAManager
.
get
(
device_id
,
m
.
size
());
hipblaslt_ext
::
UserArguments
*
d_userArgs
=
d_UAManager
.
get
(
device_id
,
m
.
size
());
// hipblaslt_ext::UserArguments* userArgs;
// NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
...
...
@@ -1529,20 +1566,20 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
}
// Get the default values from the groupedgemm object
//
groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
groupedgemm
.
getDefaultValueForDeviceUserArguments
(
userArgs
);
// Copy them to device memory
// hipblaslt_ext::UserArguments* d_userArgs;
// NVTE_CHECK_CUDA(hipMallocAsync(&d_userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), stream));
//
NVTE_CHECK_CUDA(hipMemcpyAsync(d_userArgs,
//
userArgs,
//
m.size() * sizeof(hipblaslt_ext::UserArguments),
//
hipMemcpyHostToDevice, stream));
NVTE_CHECK_CUDA
(
hipMemcpyAsync
(
d_userArgs
,
userArgs
,
m
.
size
()
*
sizeof
(
hipblaslt_ext
::
UserArguments
),
hipMemcpyHostToDevice
,
stream
));
// Make sure to initialize everytime the algo changes
//
NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
//
NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
NVTE_CHECK_HIPBLASLT
(
groupedgemm
.
initialize
(
heuristicResult
[
0
].
algo
,
workspace
,
false
,
stream
));
NVTE_CHECK_HIPBLASLT
(
groupedgemm
.
run
(
stream
));
NVTE_CHECK_HIPBLASLT
(
groupedgemm
.
initialize
(
heuristicResult
[
0
].
algo
,
workspace
));
NVTE_CHECK_HIPBLASLT
(
groupedgemm
.
run
(
d_userArgs
,
stream
));
//
NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
//
NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
// NVTE_CHECK_CUDA(hipFreeAsync(d_userArgs, stream));
// NVTE_CHECK_CUDA(hipFree(userArgs));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment