Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
a13c52ad
Commit
a13c52ad
authored
Nov 08, 2025
by
wenjh
Browse files
Fix user args core dump in mt
parent
3a5755b1
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
1962 additions
and
2005 deletions
+1962
-2005
transformer_engine/common/gemm/rocm_gemm.cu
transformer_engine/common/gemm/rocm_gemm.cu
+1962
-2005
No files found.
transformer_engine/common/gemm/rocm_gemm.cu
View file @
a13c52ad
...
@@ -1352,82 +1352,41 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
...
@@ -1352,82 +1352,41 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescDestroy
(
operationDesc
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescDestroy
(
operationDesc
));
}
}
struct
HipBlasLtUserArgsDeleter
{
class
userArgsManager
{
void
operator
()(
hipblaslt_ext
::
UserArguments
*
ptr
)
const
noexcept
{
public:
hipFree
(
ptr
);
userArgsManager
()
{}
~
userArgsManager
()
{
// Release all userArgs when the manager is destroyed
for
(
auto
&
device_pair
:
userArgs_map_
)
{
hipFree
(
device_pair
.
second
);
// Only one userArgs per device
}
}
// Get a userArgs for the given device (creates if necessary)
hipblaslt_ext
::
UserArguments
*
get
(
int
device_id
,
size_t
size
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
// Check if the userArgs for this device exists
auto
device_it
=
userArgs_map_
.
find
(
device_id
);
if
(
device_it
!=
userArgs_map_
.
end
())
{
return
device_it
->
second
;
}
}
// Create a new userArgs for this device if it doesn't exist
hipblaslt_ext
::
UserArguments
*
userArgs
;
NVTE_CHECK_CUDA
(
hipHostMalloc
(
&
userArgs
,
size
*
sizeof
(
hipblaslt_ext
::
UserArguments
)));
// Store the userArgs in the map for this device
userArgs_map_
[
device_id
]
=
userArgs
;
return
userArgs
;
}
private:
std
::
unordered_map
<
int
,
hipblaslt_ext
::
UserArguments
*>
userArgs_map_
;
// Map from device_id to hipblasHandle
std
::
mutex
mutex_
;
};
};
class
d_userArgsManager
{
using
HipBlasLtUserArgsPtr
=
std
::
unique_ptr
<
hipblaslt_ext
::
UserArguments
,
HipBlasLtUserArgsDeleter
>
;
public:
d_userArgsManager
()
{}
~
d_userArgsManager
()
{
inline
HipBlasLtUserArgsPtr
make_hipblaslt_user_args_ptr
(
size_t
size
,
bool
host
)
{
// Release all userArgs when the manager is destroyed
hipblaslt_ext
::
UserArguments
*
raw_ptr
=
nullptr
;
for
(
auto
&
device_pair
:
d_userArgs_map_
)
{
if
(
host
)
{
hipFree
(
device_pair
.
second
);
// Only one userArgs per device
NVTE_CHECK_CUDA
(
hipHostMalloc
(
&
raw_ptr
,
size
*
sizeof
(
hipblaslt_ext
::
UserArguments
)));
}
}
else
{
NVTE_CHECK_CUDA
(
hipMalloc
(
&
raw_ptr
,
size
*
sizeof
(
hipblaslt_ext
::
UserArguments
)));
}
}
return
HipBlasLtUserArgsPtr
(
raw_ptr
);
}
// Get a userArgs for the given device (creates if necessary)
inline
hipblaslt_ext
::
UserArguments
*
get_hipblaslt_user_args
(
size_t
size
,
bool
host
)
{
hipblaslt_ext
::
UserArguments
*
get
(
int
device_id
,
size_t
size
)
{
thread_local
static
std
::
unordered_map
<
size_t
,
HipBlasLtUserArgsPtr
>
host_userargs_cache
;
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
thread_local
static
std
::
unordered_map
<
size_t
,
HipBlasLtUserArgsPtr
>
device_userargs_cache
;
std
::
unordered_map
<
size_t
,
HipBlasLtUserArgsPtr
>&
user_args_cache
=
host
?
host_userargs_cache
:
device_userargs_cache
;
// Check if the userArgs for this device exists
auto
size_it
=
user_args_cache
.
find
(
size
);
auto
device_it
=
d_userArgs_map_
.
find
(
device_id
);
if
(
size_it
!=
user_args_cache
.
end
())
{
if
(
device_it
!=
d_userArgs_map_
.
end
())
{
return
size_it
->
second
.
get
();
return
device_it
->
second
;
}
}
else
// Create a new userArgs for this device if it doesn't exist
{
hipblaslt_ext
::
UserArguments
*
d_userArgs
;
HipBlasLtUserArgsPtr
user_args
=
make_hipblaslt_user_args_ptr
(
size
,
host
);
NVTE_CHECK_CUDA
(
hipMalloc
(
&
d_userArgs
,
size
*
sizeof
(
hipblaslt_ext
::
UserArguments
)));
hipblaslt_ext
::
UserArguments
*
raw_ptr
=
user_args
.
get
();
user_args_cache
[
size
]
=
std
::
move
(
user_args
);
// Store the userArgs in the map for this device
return
raw_ptr
;
d_userArgs_map_
[
device_id
]
=
d_userArgs
;
return
d_userArgs
;
}
}
}
private:
std
::
unordered_map
<
int
,
hipblaslt_ext
::
UserArguments
*>
d_userArgs_map_
;
// Map from device_id to hipblasHandle
std
::
mutex
mutex_
;
};
// Define a static userArgs manager
static
userArgsManager
UAManager
;
static
d_userArgsManager
d_UAManager
;
void
hipblaslt_groupedgemm
(
std
::
vector
<
const
Tensor
*>&
inputA
,
std
::
vector
<
const
Tensor
*>&
inputB
,
void
hipblaslt_groupedgemm
(
std
::
vector
<
const
Tensor
*>&
inputA
,
std
::
vector
<
const
Tensor
*>&
inputB
,
std
::
vector
<
Tensor
*>&
outputD
,
std
::
vector
<
int64_t
>&
m
,
std
::
vector
<
Tensor
*>&
outputD
,
std
::
vector
<
int64_t
>&
m
,
...
@@ -1438,10 +1397,8 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
...
@@ -1438,10 +1397,8 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// Check compute_stream_offset valid.
// Check compute_stream_offset valid.
NVTE_CHECK
(
compute_stream_offset
>=
-
1
&&
compute_stream_offset
<
compute_num_streams
);
NVTE_CHECK
(
compute_stream_offset
>=
-
1
&&
compute_stream_offset
<
compute_num_streams
);
int
device_id
;
hipblaslt_ext
::
UserArguments
*
userArgs
=
get_hipblaslt_user_args
(
m
.
size
(),
true
);
hipGetDevice
(
&
device_id
);
hipblaslt_ext
::
UserArguments
*
d_userArgs
=
get_hipblaslt_user_args
(
m
.
size
(),
false
);
hipblaslt_ext
::
UserArguments
*
userArgs
=
UAManager
.
get
(
device_id
,
m
.
size
());
hipblaslt_ext
::
UserArguments
*
d_userArgs
=
d_UAManager
.
get
(
device_id
,
m
.
size
());
// hipblaslt_ext::UserArguments* userArgs;
// hipblaslt_ext::UserArguments* userArgs;
// NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
// NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment