OpenDAS / TransformerEngine

Commit 47077129, authored Oct 16, 2025 by yuguo (parent aa62d24c)

[DCU] remove redundant gemm
Showing 4 changed files with 6 additions and 494 deletions. The redundant `nvte_cublas_batchgemm_v2` wrapper is deleted (its caller now invokes `nvte_cublas_batchgemm` directly), `nvte_cublas_batchgemm_v3` is renamed to `nvte_cublas_batchgemm_tensorwise_int8`, and the hipBLASLt tensorwise-INT8 batch GEMM implementation is removed, with the remaining entry point stubbed out by an error.

- transformer_engine/common/gemm/cublaslt_gemm.cu (+3, -81)
- transformer_engine/common/gemm/rocm_gemm.cu (+0, -405)
- transformer_engine/common/include/transformer_engine/gemm.h (+1, -6)
- transformer_engine/pytorch/csrc/extensions/gemm.cpp (+2, -2)
transformer_engine/common/gemm/cublaslt_gemm.cu

```diff
@@ -1166,82 +1166,13 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
                          stream);
 }
 
-// add for batchgemm
-void nvte_cublas_batchgemm_v2(const NVTETensor A, const NVTETensor B, NVTETensor D,
-                              const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
-                              bool transb, bool grad, NVTETensor workspace, bool accumulate,
-                              bool use_split_accumulator, int math_sm_count, int batch_count,
-                              cudaStream_t stream) {
-  NVTE_API_CALL(nvte_cublas_batchgemm_v2);
-  using namespace transformer_engine;
-  const Tensor *inputA = convertNVTETensorCheck(A);
-  const Tensor *inputB = convertNVTETensorCheck(B);
-  Tensor *outputD = convertNVTETensor(D);
-  const Tensor *biasTensor = convertNVTETensor(bias);
-  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
-  Tensor *wspace = convertNVTETensor(workspace);
-  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr)) {
-    NVTE_ERROR("MOE batchgemm does not support bias or gelu.");
-  }
-  int m, n, k;
-  if (!transa && transb) {  // NT
-    m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
-    k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
-    n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
-  } else if (transa && !transb) {  // TN
-    m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
-    k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
-    n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
-  } else if (!transa && !transb) {  // NN
-    m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
-    k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
-    n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
-  }
-  int lda, ldb, ldd;
-  if (transa && !transb) {  // TN
-    lda = k;
-    ldb = k;
-    ldd = m;
-  } else if (!transa && !transb) {  // NN
-    lda = m;
-    ldb = k;
-    ldd = m;
-  } else if (!transa && transb) {  // NT
-    lda = m;
-    ldb = n;
-    ldd = m;
-  } else {  // TT
-    NVTE_ERROR("TT layout not allowed.");
-  }
-  hipblas_batchgemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
-                    (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
-                    grad, wspace->data.dptr, wspace->data.shape[0], accumulate,
-                    use_split_accumulator, math_sm_count, 0, 0, false, nullptr, batch_count,
-                    stream);
-}
-
 // add for batchgemm
-void nvte_cublas_batchgemm_v3(const NVTETensor A, const NVTETensor B, const NVTETensor A_scales,
-                              const NVTETensor B_scales, NVTETensor D, const NVTETensor bias,
+void nvte_cublas_batchgemm_tensorwise_int8(const NVTETensor A, const NVTETensor B,
+                                           const NVTETensor A_scales, const NVTETensor B_scales,
+                                           NVTETensor D, const NVTETensor bias,
                               NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                               NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                               int math_sm_count, int batch_count, cudaStream_t stream) {
-  NVTE_API_CALL(nvte_cublas_batchgemm_v3);
+  NVTE_API_CALL(nvte_cublas_batchgemm_tensorwise_int8);
   using namespace transformer_engine;
   const Tensor *inputA = convertNVTETensorCheck(A);
   const Tensor *inputB = convertNVTETensorCheck(B);
```
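Everything the deleted `nvte_cublas_batchgemm_v2` wrapper did beyond forwarding to `hipblas_batchgemm` was deriving per-batch GEMM dimensions and column-major leading dimensions from the transpose flags (note that its three m/n/k branches computed identical ternaries, which is part of why it was redundant). Below is a minimal self-contained sketch of that derivation; the helper name, the two-element shape arrays, and the exception type are ours for illustration and are not part of the codebase:

```cpp
#include <cstddef>
#include <stdexcept>

struct GemmDims {
  int m, n, k;        // per-batch problem size: D (m x n) = op(A) (m x k) * op(B) (k x n)
  int lda, ldb, ldd;  // column-major leading dimensions
};

// Hypothetical helper mirroring the removed wrapper's logic. Batches are stacked
// along the outer axis of each operand, so that extent is divided by batch_count.
GemmDims derive_batchgemm_dims(const std::size_t a_shape[2], const std::size_t b_shape[2],
                               bool transa, bool transb, int batch_count) {
  GemmDims d{};
  d.m = static_cast<int>(transa ? a_shape[0] / batch_count : a_shape[1]);
  d.k = static_cast<int>(transa ? a_shape[1] : a_shape[0] / batch_count);
  d.n = static_cast<int>(transb ? b_shape[1] : b_shape[0] / batch_count);
  if (transa && !transb) {          // TN
    d.lda = d.k; d.ldb = d.k; d.ldd = d.m;
  } else if (!transa && !transb) {  // NN
    d.lda = d.m; d.ldb = d.k; d.ldd = d.m;
  } else if (!transa && transb) {   // NT
    d.lda = d.m; d.ldb = d.n; d.ldd = d.m;
  } else {                          // TT was rejected by the original code as well
    throw std::invalid_argument("TT layout not allowed.");
  }
  return d;
}
```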
```diff
@@ -1297,16 +1228,7 @@ void nvte_cublas_batchgemm_v3(const NVTETensor A, const NVTETensor B, const NVTE
   handle = hipblaslt_handles[0];
 
-  hipblaslt_batchgemm_tensorwise_int8(inputA, inputB, inputA_scales, inputB_scales, outputD,
-                                      biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
-                                      (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
-                                      (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N, grad,
-                                      wspace->data.dptr, wspace->data.shape[0], accumulate,
-                                      use_split_accumulator, math_sm_count, 0, 0, false, nullptr,
-                                      batch_count, stream, handle);
+  NVTE_ERROR("Remove nvte_cublas_batchgemm_tensorwise_int8 for now.");
 }
 #endif
```
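The `hipblaslt_batchgemm_tensorwise_int8` implementation this call targeted (removed from rocm_gemm.cu below) expressed the batch as a strided-batched hipBLASLt problem: one matrix layout per operand, plus a batch count and a constant stride between consecutive matrices. A minimal sketch of that attribute setup, assuming densely packed batches (stride = rows * cols, exactly as the removed code used); the helper function is ours and error checking is elided:

```cpp
#include <hipblaslt/hipblaslt.h>

// Turn a plain matrix layout into a strided-batched one. The removed code used
// strideA = m * k, strideB = k * n, strideD = m * n, i.e. batches packed back to back.
void make_strided_batched(hipblasLtMatrixLayout_t layout, int32_t batch_count, int64_t stride) {
  hipblasLtMatrixLayoutSetAttribute(layout, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
                                    &batch_count, sizeof(batch_count));
  hipblasLtMatrixLayoutSetAttribute(layout, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
                                    &stride, sizeof(stride));
}
```

Passing a correctly sized `int32_t` for the batch count also sidesteps the mismatch visible in the removed code, which passed the address of a `size_t` with `sizeof(int32_t)`.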
transformer_engine/common/gemm/rocm_gemm.cu

```diff
@@ -1352,411 +1352,6 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
   NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
 }
 
-void hipblaslt_batchgemm_tensorwise_int8(
-    const Tensor* inputA, const Tensor* inputB, const Tensor* inputA_scales,
-    const Tensor* inputB_scales, Tensor* outputD, const Tensor* inputBias,
-    Tensor* outputPreGelu, int m, int n, int k, int lda, int ldb, int ldd,
-    hipblasOperation_t transa, hipblasOperation_t transb, bool grad, void* workspace,
-    size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count,
-    int m_split, int n_split, bool gemm_producer, const Tensor* inputCounter,
-    size_t batch_count, hipStream_t stream, hipblasLtHandle_t handle) {
-  void* A = inputA->data.dptr;
-  void* A_scale_inverse = inputA_scales->data.dptr;
-  float* A_scale_inverse_float = (float*)(inputA_scales->data.dptr);
-  void* B = inputB->data.dptr;
-  void* B_scale_inverse = inputB_scales->data.dptr;
-  float* B_scale_inverse_float = (float*)(inputB_scales->data.dptr);
-  void* D = outputD->data.dptr;
-  void* bias_ptr = inputBias->data.dptr;
-  const bool bias = bias_ptr != nullptr;
-  void* pre_gelu_out = outputPreGelu->data.dptr;
-  const bool gelu = pre_gelu_out != nullptr;
-  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);
-  const bool use_int8 = is_int8_dtype(inputA->data.dtype) || is_int8_dtype(inputB->data.dtype);
-  const hipDataType A_type = get_hipblaslt_dtype(inputA->data.dtype);
-  const hipDataType B_type = get_hipblaslt_dtype(inputB->data.dtype);
-  const hipDataType D_type = get_hipblaslt_dtype(outputD->data.dtype);
-  const hipDataType bias_type = get_hipblaslt_dtype(inputBias->data.dtype);
-
-  NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
-             "FP8 input to GEMM requires inverse of scale!");
-  NVTE_CHECK(!is_fp8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
-             "FP8 input to GEMM requires inverse of scale!");
-  NVTE_CHECK(!is_int8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
-             "INT8 input to GEMM requires inverse of scale!");
-  NVTE_CHECK(!is_int8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
-             "INT8 input to GEMM requires inverse of scale!");
-
-  bool tensorwise_int8 = false;
-  const char* NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
-  if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1' &&
-      use_int8)
-    tensorwise_int8 = true;
-
-  // check consistency of arguments:
-  // if fp8 is desired, context cannot be null
-  // fp8 + gelu fusion + fp8 aux is unavailable right now.
-  if (use_fp8 || use_int8) {
-    NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!");
-  }
-
-  float one = 1.0;
-  float zero = 0.0;
-  float beta = (accumulate) ? one : zero;
-
-  int device_id;
-  NVTE_CHECK_CUDA(hipGetDevice(&device_id));
-  if (handle == nullptr) {
-    handle = cached_handles.get(device_id);
-    if (handle == nullptr) {
-      handle = cached_handles.obtain(device_id);
-    }
-  }
-
-  hipblasLtMatmulDesc_t operationDesc = nullptr;
-  hipblasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
-  hipblasLtMatmulPreference_t preference = nullptr;
-  hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT;
-  int64_t ld_gelumat = (int64_t)ldd;
-
-  // default to tf32 except for e5m2 inputs where the config is not supported
-  hipblasComputeType_t gemm_compute_type = HIPBLAS_COMPUTE_32F;
-
-  // Create matrix descriptors. Not setting any extra attributes.
-  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(
-      &Adesc, A_type, transa == HIPBLAS_OP_N ? m : k, transa == HIPBLAS_OP_N ? k : m, lda));
-  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(
-      &Bdesc, B_type, transb == HIPBLAS_OP_N ? k : n, transb == HIPBLAS_OP_N ? n : k, ldb));
-  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
-
-  if (tensorwise_int8) {
-    size_t strideA = m * k;
-    size_t strideB = k * n;
-    size_t strideD = m * n;
-    hipblasLtMatrixLayoutSetAttribute(Adesc, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count,
-                                      sizeof(int32_t));
-    hipblasLtMatrixLayoutSetAttribute(Adesc, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
-                                      &strideA, sizeof(int64_t));
-    hipblasLtMatrixLayoutSetAttribute(Bdesc, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count,
-                                      sizeof(int32_t));
-    hipblasLtMatrixLayoutSetAttribute(Bdesc, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
-                                      &strideB, sizeof(int64_t));
-    hipblasLtMatrixLayoutSetAttribute(Ddesc, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count,
-                                      sizeof(int32_t));
-    hipblasLtMatrixLayoutSetAttribute(Ddesc, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
-                                      &strideD, sizeof(int64_t));
-    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
-  } else {
-    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
-  }
-  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSA,
-                                                       &transa, sizeof(transa)));
-  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSB,
-                                                       &transb, sizeof(transb)));
-
-  // set fp8 attributes -- input and output types should already be set to fp8 as appropriate
-  // Note: gelu fusion isn't available right now, and we don't need
-  // amax(D) either (next op is high precision).
-  if (use_fp8) {
-    // Split accumulator.
-    const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
-    /*
-    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
-        HIPBLASLT_MATMUL_DESC_FAST_ACCUM,  // TODO: We don't have fast accum mode yet
-        &fastAccuMode, sizeof(fastAccuMode)));
-    */
-    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
-        operationDesc, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, &A_scale_inverse,
-        sizeof(A_scale_inverse)));
-    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
-        operationDesc, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, &B_scale_inverse,
-        sizeof(B_scale_inverse)));
-    if (bias) {
-      NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
-          operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
-    }
-  }
-  if (tensorwise_int8) {
-    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
-        operationDesc, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, (void*)&A_scale_inverse_float,
-        sizeof(void*)));
-    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
-        operationDesc, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, (void*)&B_scale_inverse_float,
-        sizeof(void*)));
-    if (bias) {
-      NVTE_CHECK(false, "tensorwise_int8 does not support bias!");
-    }
-  }
-
-  if (bias && gelu) {
-    if (grad) {
-      epilogue = HIPBLASLT_EPILOGUE_DGELU_BGRAD;
-    } else {
-      epilogue = HIPBLASLT_EPILOGUE_GELU_AUX_BIAS;
-    }
-    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
-        operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
-    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
-        operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_gelu_out,
-        sizeof(pre_gelu_out)));
-    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
-        operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
-  } else if (bias) {
-    if (grad) {
-      // grad output is always input B
-      epilogue = HIPBLASLT_EPILOGUE_BGRADB;
-    } else {
-      epilogue = HIPBLASLT_EPILOGUE_BIAS;
-    }
-    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
-        operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
-  } else if (gelu) {
-    if (grad) {
-      epilogue = HIPBLASLT_EPILOGUE_DGELU;
-    } else {
-      epilogue = HIPBLASLT_EPILOGUE_GELU_AUX;
-    }
-    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
-        operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_gelu_out,
-        sizeof(pre_gelu_out)));
-    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
-        operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
-  }
-  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
-      operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));
-
-  GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type,
-                              use_fp8 ? bias_type : (hipDataType)-1, m, n, k, lda, ldb, ldd,
-                              transa, transb, epilogue);
-  GemmAlgoCache::Algo cached_algo;
-  if (algoCache.find(gemm_cfg, workspaceSize, cached_algo) == 0 ||
-      !cached_algo.algo.has_value()) {
-    int firstAlgo = getIntEnv("TE_HIPBLASLT_ALGO_SELECTION", 0, 0);
-    int tuneLoopCount = getIntEnv("TE_HIPBLASLT_TUNING_RUN_COUNT", 0, 0);
-    int algoTuneCount = 1;
-    std::vector<hipblasLtMatmulHeuristicResult_t> algoArr;
-    bool logTuning = getIntEnv("TE_HIPBLASLT_LOG_TUNING", 0, 0) != 0;
-    if (tuneLoopCount) {
-      /* HIPBLASLT may return hundreds of algos for some configs
-       * Limit amount by default. User may override with env
-       */
-      static const int defaultAlgoCount = 16;
-      algoTuneCount = getIntEnv("TE_HIPBLASLT_TUNING_ALGO_COUNT", defaultAlgoCount, 1);
-    }
-    algoTuneCount += firstAlgo;
-    int algoTotalCount =
-        cached_algo.hasId() ? std::max(algoTuneCount, (cached_algo.index + 1)) : algoTuneCount;
-    algoArr.resize(algoTotalCount);
-    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceCreate(&preference));
-    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceSetAttribute(
-        preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize,
-        sizeof(workspaceSize)));
-    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc,
-                                                         Ddesc, Ddesc, preference, algoTotalCount,
-                                                         algoArr.data(), &algoTotalCount));
-    algoArr.resize(algoTotalCount);
-    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceDestroy(preference));
-
-    // If cached algo exists in persistent storage we just need to find matching
-    // hipblasLtMatmulAlgo_t
-    if (cached_algo.hasId()) {
-      int idx = (cached_algo.index < algoTotalCount) ? cached_algo.index : 0;
-      for (int i = 0; i < algoTotalCount; i++) {
-        const auto& algo = algoArr[idx];
-        if (algo.state == HIPBLAS_STATUS_SUCCESS) {
-          if (cached_algo.algoId == cached_algo.getAlgoId(algo.algo)) {
-            cached_algo.algo = algo.algo;
-            if (algo.workspaceSize != cached_algo.ws_size_min || idx != cached_algo.index) {
-              cached_algo.ws_size_min = algo.workspaceSize;
-              cached_algo.index = idx;
-              algoCache.store(gemm_cfg, cached_algo);
-            }
-            break;
-          }
-        }
-        idx = (idx + 1) % algoTotalCount;
-      }
-      if (logTuning && !cached_algo.algo.has_value()) {
-        std::cout << "[WARNING] Cannot find cached algoId " << cached_algo.algoId
-                  << " in hipBLASLt results" << std::endl;
-      }
-    }
-
-    // No suitable entry in autotune cache or could not find matched algo in hipBLASLt results
-    if (!cached_algo.algo.has_value()) {
-      int bestAlgo = -1;
-      algoTuneCount = std::min(algoTuneCount, algoTotalCount);
-      if (tuneLoopCount > 0) {
-        if (logTuning)
-          std::cout << "[INFO] Perform hipBLASLt algo selection on GPU " << device_id
-                    << " in range [" << firstAlgo << "-" << (algoTuneCount - 1) << "] with "
-                    << tuneLoopCount << " loops " << std::endl;
-        NVTE_CHECK_CUDA(hipStreamSynchronize(stream));
-        hipStream_t profilingStream;
-        NVTE_CHECK_CUDA(hipStreamCreateWithFlags(&profilingStream, hipStreamNonBlocking));
-        using tuning_clock = std::chrono::steady_clock;
-        tuning_clock::now();  // the first call takes a little longer so do it outside the loop
-        tuning_clock::duration bestTime = tuning_clock::duration::max();
-        for (int algo = firstAlgo; algo < algoTuneCount; algo++) {
-          if (algoArr[algo].state != HIPBLAS_STATUS_SUCCESS) {
-            continue;
-          }
-          // Warm-up call
-          NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc,
-                                               static_cast<const void*>(&one),  /* alpha */
-                                               A, Adesc,                        /* A */
-                                               B, Bdesc,                        /* B */
-                                               static_cast<const void*>(&beta), /* beta */
-                                               D, Ddesc,                        /* C */
-                                               D, Ddesc,                        /* D */
-                                               &algoArr[algo].algo,             /* algo */
-                                               workspace, workspaceSize,        /* workspace */
-                                               profilingStream));               /* stream */
-          NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream));
-          // Profiling loop
-          tuning_clock::time_point startTime = tuning_clock::now();
-          for (int loop = 0; loop < tuneLoopCount; loop++) {
-            NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc,
-                                                 static_cast<const void*>(&one),  /* alpha */
-                                                 A, Adesc,                        /* A */
-                                                 B, Bdesc,                        /* B */
-                                                 static_cast<const void*>(&beta), /* beta */
-                                                 D, Ddesc,                        /* C */
-                                                 D, Ddesc,                        /* D */
-                                                 &algoArr[algo].algo,             /* algo */
-                                                 workspace, workspaceSize,        /* workspace */
-                                                 profilingStream));               /* stream */
-          }
-          NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream));
-          tuning_clock::duration algoTime = tuning_clock::now() - startTime;
-          if (algoTime < bestTime) {
-            bestAlgo = algo;
-            bestTime = algoTime;
-          }
-        }
-        NVTE_CHECK_CUDA(hipStreamDestroy(profilingStream));
-        if (bestAlgo >= 0) {
-          if (logTuning)
-            std::cout << "[INFO] Select hipBLASLt algo " << bestAlgo << " with time "
-                      << std::chrono::duration_cast<std::chrono::nanoseconds>(bestTime).count() /
-                             tuneLoopCount
-                      << " ns" << std::endl;
-        }
-      } else if (firstAlgo < algoTuneCount) {
-        bestAlgo = firstAlgo;
-      }
-      if (bestAlgo < 0) {
-        NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc));
-        NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc));
-        NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc));
-        NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
-        throw std::runtime_error("Unable to find any suitable algorithms");
-      }
-      cached_algo.algo = algoArr[bestAlgo].algo;
-      cached_algo.index = bestAlgo;
-      cached_algo.algoId = cached_algo.getAlgoId(algoArr[bestAlgo].algo);
-      cached_algo.ws_size_min = algoArr[bestAlgo].workspaceSize;
-      cached_algo.ws_size_max = workspaceSize;
-      if (logTuning)
-        std::cout << "[INFO] Use hipBLASLt algo [" << bestAlgo << "] " << cached_algo.algoId
-                  << std::endl;
-      algoCache.store(gemm_cfg, cached_algo);
-    }
-  }
-
-  // D = alpha * (A * B) + beta * C
-  NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc,
-                                       static_cast<const void*>(&one),  /* alpha */
-                                       A, Adesc,                        /* A */
-                                       B, Bdesc,                        /* B */
-                                       static_cast<const void*>(&beta), /* beta */
-                                       D, Ddesc,                        /* C */
-                                       D, Ddesc,                        /* D */
-                                       &cached_algo.algo.value(),       /* algo */
-                                       workspace, workspaceSize,        /* workspace */
-                                       stream));                        /* stream */
-
-  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc));
-  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc));
-  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc));
-  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
-}
 
 class userArgsManager {
  public:
```
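The deletion drops only the INT8 copy of this logic; the same env-driven tuning pattern (TE_HIPBLASLT_ALGO_SELECTION, TE_HIPBLASLT_TUNING_RUN_COUNT, TE_HIPBLASLT_LOG_TUNING) appears to live on in the retained `hipblaslt_gemm` path. The technique itself is generic: warm up each heuristic candidate once, time a fixed number of repetitions, and keep the fastest. Here is a stripped-down sketch with a synchronous stand-in for the kernel launch; the function name and the `std::function` interface are ours, not the codebase's:

```cpp
#include <chrono>
#include <functional>
#include <vector>

// Returns the index of the fastest candidate, or -1 if the list is empty.
int pick_fastest(const std::vector<std::function<void()>>& candidates, int loop_count) {
  using clock = std::chrono::steady_clock;
  int best = -1;
  clock::duration best_time = clock::duration::max();
  for (int i = 0; i < static_cast<int>(candidates.size()); ++i) {
    candidates[i]();  // warm-up call: the first launch is slower, keep it out of the timing
    const auto start = clock::now();
    for (int loop = 0; loop < loop_count; ++loop) candidates[i]();
    const auto elapsed = clock::now() - start;
    if (elapsed < best_time) { best_time = elapsed; best = i; }
  }
  return best;
}
```

The real code additionally runs each candidate on a dedicated non-blocking HIP stream and synchronizes it around the timed region, so host-side timing brackets the actual device work.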
transformer_engine/common/include/transformer_engine/gemm.h

```diff
@@ -152,12 +152,7 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
                            NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                            int math_sm_count, int batch_count, cudaStream_t stream);
 
-void nvte_cublas_batchgemm_v2(const NVTETensor A, const NVTETensor B, NVTETensor D,
-                              const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
-                              bool transb, bool grad, NVTETensor workspace, bool accumulate,
-                              bool use_split_accumulator, int math_sm_count, int batch_count,
-                              cudaStream_t stream);
-
-void nvte_cublas_batchgemm_v3(const NVTETensor A, const NVTETensor B, const NVTETensor A_scales,
-                              const NVTETensor B_scales, NVTETensor D, const NVTETensor bias,
+void nvte_cublas_batchgemm_tensorwise_int8(const NVTETensor A, const NVTETensor B,
+                                           const NVTETensor A_scales, const NVTETensor B_scales,
+                                           NVTETensor D, const NVTETensor bias,
                               NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                               NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                               int math_sm_count, int batch_count, cudaStream_t stream);
```
transformer_engine/pytorch/csrc/extensions/gemm.cpp

```diff
@@ -588,7 +588,7 @@ std::vector<py::object> generic_batchgemm(py::handle A, bool transa, py::handle
   } else {
     // Launch GEMM
     NVTE_SCOPED_GIL_RELEASE({
-      nvte_cublas_batchgemm_v2(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(),
+      nvte_cublas_batchgemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(),
                             te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
                             accumulate, use_split_accumulator, num_math_sms, batch_count,
                             main_stream);
     });
@@ -724,7 +724,7 @@ std::vector<py::object> tensorwise_int8_batchgemm(py::handle A, bool transa, py:
   } else {
     // Launch GEMM
     NVTE_SCOPED_GIL_RELEASE({
-      nvte_cublas_batchgemm_v3(A_tensor.data(), B_tensor.data(), A_scales_tensor.data(),
-                               B_scales_tensor.data(), D_tensor.data(), bias_tensor.data(),
+      nvte_cublas_batchgemm_tensorwise_int8(A_tensor.data(), B_tensor.data(),
+                                            A_scales_tensor.data(), B_scales_tensor.data(),
+                                            D_tensor.data(), bias_tensor.data(),
                             te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
                             accumulate, use_split_accumulator, num_math_sms, batch_count,
                             main_stream);
     });
```
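Both call sites keep their `NVTE_SCOPED_GIL_RELEASE` wrapper, so the native GEMM launch runs without holding the Python GIL. Assuming that macro is a thin wrapper over pybind11's scoped release (an assumption; its definition is not shown in this diff), the underlying pattern looks like this:

```cpp
#include <functional>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hypothetical illustration: run a long native launch while letting other
// Python threads make progress.
void launch_without_gil(const std::function<void()>& gemm_launch) {
  py::gil_scoped_release release;  // GIL dropped here
  gemm_launch();                   // native work proceeds without the GIL
}                                  // GIL re-acquired when `release` goes out of scope
```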