Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
f4bd89eb
Commit
f4bd89eb
authored
Nov 12, 2025
by
wenjh
Browse files
Fix hipblaslt handle management
parent
a13c52ad
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
138 deletions
+8
-138
transformer_engine/common/gemm/rocm_gemm.cu
transformer_engine/common/gemm/rocm_gemm.cu
+8
-138
No files found.
transformer_engine/common/gemm/rocm_gemm.cu
View file @
f4bd89eb
...
...
@@ -465,106 +465,6 @@ transformer_engine::DType get_transformer_engine_dtype(const rocblas_datatype t)
namespace
{
static
class
HandlePool
{
public:
hipblasLtHandle_t
get
(
int
device_id
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mt
);
if
(
pool
.
empty
())
{
int
device_count
=
0
;
NVTE_CHECK_CUDA
(
hipGetDeviceCount
(
&
device_count
));
pool
.
resize
(
device_count
);
return
nullptr
;
}
if
(
!
pool
[
device_id
].
empty
())
{
hipblasLtHandle_t
h
=
pool
[
device_id
].
front
();
pool
[
device_id
].
pop_front
();
return
h
;
}
return
nullptr
;
}
hipblasLtHandle_t
obtain
(
int
device_id
)
{
hipblasLtHandle_t
h
=
get
(
device_id
);
if
(
h
==
nullptr
)
{
NVTE_CHECK_HIPBLASLT
(
hipblasLtCreate
(
&
h
));
}
return
h
;
}
void
store
(
const
std
::
vector
<
hipblasLtHandle_t
>&
handles
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mt
);
if
(
pool
.
empty
())
{
std
::
cout
<<
"[ERROR] Attempt to store handles to invalid pool"
<<
std
::
endl
;
}
for
(
unsigned
int
i
=
0
;
i
<
pool
.
size
();
i
++
)
{
if
(
handles
[
i
]
!=
nullptr
)
{
pool
[
i
].
push_front
(
handles
[
i
]);
}
}
}
~
HandlePool
()
{
#if DESTROY_HIPBLASLT_HANDLES_POOL
std
::
lock_guard
<
std
::
mutex
>
lock
(
mt
);
for
(
auto
&
hlist
:
pool
)
{
for
(
auto
&
h
:
hlist
)
{
hipblasLtDestroy
(
h
);
}
}
pool
.
clear
();
#endif
}
inline
size_t
get_size
()
const
{
return
pool
.
size
();
}
private:
std
::
mutex
mt
;
using
Pool
=
std
::
vector
<
std
::
forward_list
<
hipblasLtHandle_t
>>
;
// Order of destructors between thread_local and global is not actually guaranteed
// As a simple w/a make pool storage "leaky"
// Just do not destruct it and do not destroy hipbladLt handles
// Let OS deal with it on application exit
#if DESTROY_HIPBLASLT_HANDLES_POOL
Pool
pool
;
#else
Pool
&
pool
=
*
new
Pool
();
#endif
}
handle_pool
;
thread_local
static
class
HandleCache
{
public:
hipblasLtHandle_t
get
(
int
device_id
)
const
{
return
d
.
empty
()
?
nullptr
:
d
[
device_id
];
}
hipblasLtHandle_t
obtain
(
int
device_id
)
{
hipblasLtHandle_t
h
=
get
(
device_id
);
if
(
h
)
{
return
h
;
}
h
=
handle_pool
.
obtain
(
device_id
);
set
(
device_id
,
h
);
return
h
;
}
void
set
(
int
device_id
,
hipblasLtHandle_t
h
)
{
if
(
d
.
empty
())
{
d
.
resize
(
handle_pool
.
get_size
());
}
d
[
device_id
]
=
h
;
}
~
HandleCache
()
{
if
(
!
d
.
empty
())
{
handle_pool
.
store
(
d
);
}
}
private:
std
::
vector
<
hipblasLtHandle_t
>
d
;
}
cached_handles
;
class
csv_helper
{
public:
struct
start
{};
...
...
@@ -987,18 +887,12 @@ static inline int getIntEnv(const char* name, int defval, int minval) {
}
//namespace
/* Warning: only call once per device!
 * When calling nvte_multi_stream_cublas_gemm with hipblaslt backend
 * need to create multiple handles corresponding to compute_streams
 * to avoid a handle be used by multi-streams concurrently.
 */
static void init_hipblaslt_handles(hipblasLtHandle_t* hipblaslt_handles) {
  // Destination array must be non-null and hold compute_num_streams entries.
  NVTE_CHECK(hipblaslt_handles != nullptr);
  // One handle per compute stream, so no handle is shared between streams.
  for (int stream_idx = 0; stream_idx < compute_num_streams; ++stream_idx) {
    NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&hipblaslt_handles[stream_idx]));
  }
}
// Factory callback passed to HandleManager: creates a hipblasLt handle,
// aborting via NVTE_CHECK_HIPBLASLT on failure.
static inline void CreateHipBlasLtHandle(hipblasLtHandle_t* handle) {
  NVTE_CHECK_HIPBLASLT(hipblasLtCreate(handle));
}

// Manager that hands out hipblasLt handles, with creation delegated to
// CreateHipBlasLtHandle (semantics defined by detail::HandleManager).
using hipBlasLtHandleManager = detail::HandleManager<hipblasLtHandle_t, CreateHipBlasLtHandle>;
transformer_engine
::
DType
get_transformer_engine_dtype_from_hipblaslt_dtype
(
const
hipDataType
t
)
{
using
namespace
transformer_engine
;
switch
(
t
)
{
...
...
@@ -1018,8 +912,7 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
int
ldb
,
int
ldd
,
hipblasOperation_t
transa
,
hipblasOperation_t
transb
,
bool
grad
,
void
*
workspace
,
size_t
workspaceSize
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
m_split
,
int
n_split
,
bool
gemm_producer
,
const
Tensor
*
inputCounter
,
hipStream_t
stream
,
hipblasLtHandle_t
handle
)
{
bool
gemm_producer
,
const
Tensor
*
inputCounter
,
hipStream_t
stream
)
{
void
*
A
=
inputA
->
data
.
dptr
;
void
*
A_scale_inverse
=
inputA
->
scale_inv
.
dptr
;
float
*
A_scale_inverse_float
=
(
float
*
)(
inputA
->
scale_inv
.
dptr
);
...
...
@@ -1064,12 +957,7 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
int
device_id
;
NVTE_CHECK_CUDA
(
hipGetDevice
(
&
device_id
));
if
(
handle
==
nullptr
)
{
handle
=
cached_handles
.
get
(
device_id
);
if
(
handle
==
nullptr
)
{
handle
=
cached_handles
.
obtain
(
device_id
);
}
}
hipblasLtHandle_t
handle
=
hipBlasLtHandleManager
::
Instance
().
GetHandle
();
hipblasLtMatmulDesc_t
operationDesc
=
nullptr
;
hipblasLtMatrixLayout_t
Adesc
=
nullptr
,
Bdesc
=
nullptr
,
Cdesc
=
nullptr
,
Ddesc
=
nullptr
;
...
...
@@ -1403,15 +1291,7 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// hipblaslt_ext::UserArguments* userArgs;
// NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
hipblasLtHandle_t
handle
=
nullptr
;
if
(
compute_stream_offset
!=
-
1
)
{
// Init hipblaslt handles (once, globally)
static
std
::
once_flag
init_flag
;
static
hipblasLtHandle_t
hipblaslt_handles
[
compute_num_streams
];
std
::
call_once
(
init_flag
,
init_hipblaslt_handles
,
hipblaslt_handles
);
handle
=
hipblaslt_handles
[
compute_stream_offset
];
}
hipblasLtHandle_t
handle
=
hipBlasLtHandleManager
::
Instance
().
GetHandle
();
const
hipDataType
A_type
=
get_hipblaslt_dtype
(
inputA
[
0
]
->
data
.
dtype
);
const
hipDataType
B_type
=
get_hipblaslt_dtype
(
inputB
[
0
]
->
data
.
dtype
);
...
...
@@ -1929,20 +1809,10 @@ void cublas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
// Check compute_stream_offset valid.
NVTE_CHECK
(
compute_stream_offset
>=
-
1
&&
compute_stream_offset
<
compute_num_streams
);
hipblasLtHandle_t
handle
=
nullptr
;
if
(
compute_stream_offset
!=
-
1
)
{
// Init hipblaslt handles (once, globally)
static
std
::
once_flag
init_flag
;
static
hipblasLtHandle_t
hipblaslt_handles
[
compute_num_streams
];
std
::
call_once
(
init_flag
,
init_hipblaslt_handles
,
hipblaslt_handles
);
handle
=
hipblaslt_handles
[
compute_stream_offset
];
}
hipblaslt_gemm
(
inputA
,
inputB
,
outputD
,
inputBias
,
outputPreGelu
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
(
transa
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
(
transb
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
grad
,
workspace
,
workspaceSize
,
accumulate
,
use_split_accumulator
,
math_sm_count
,
m_split
,
n_split
,
gemm_producer
,
inputCounter
,
stream
,
handle
);
m_split
,
n_split
,
gemm_producer
,
inputCounter
,
stream
);
return
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment