OpenDAS / apex · Commit 9615983e
Authored Dec 09, 2021 by Masaki Kozuki; committed by hubertlu-tw, Dec 09, 2021
Remove `THCState` from `apex/contrib/multihead_attn` (#1239)

* pass `self.mask_additive`
* clang-format
* removing THCState
Parent: d11ddccf
Showing 1 changed file with 45 additions and 123 deletions (+45 −123).
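The user-visible effect of the change is an API break: the exported GEMM helpers lose their leading `THCState *state` parameter. A minimal before/after sketch for a hypothetical caller (the function, shapes, and strides here are illustrative, not taken from the commit):

#include <cuda_fp16.h>
#include "strided_batched_gemm.h"  // the header changed in this commit

// Hypothetical call site: only the leading THCState* argument goes away;
// every other argument keeps its old position and meaning.
void run_gemm(const half *a, const half *b, half *c, half *d,
              long m, long n, long k, long batchCount) {
  // Before: HgemmStridedBatched(state, 'n', 'n', m, n, k, ...);
  HgemmStridedBatched('n', 'n', m, n, k,
                      /*alpha=*/1.0f, a, /*lda=*/m, /*strideA=*/m * k,
                      b, /*ldb=*/k, /*strideB=*/k * n,
                      /*beta=*/0.0f, c, /*ldc=*/m, /*strideC=*/m * n,
                      d, /*ldd=*/m, /*strideD=*/m * n, batchCount);
}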
apex/contrib/csrc/multihead_attn/strided_batched_gemm.h (+45 −123)
-#include <vector>
 #include <iostream>
+#include <vector>
 #include <cuda.h>
-#include <cuda_runtime.h>
 #include <cuda_fp16.h>
-//#include <cuda_profiler_api.h>
+#include <cuda_runtime.h>
 //#include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/Exceptions.h>
 #include "cutlass/cutlass.h"
...
@@ -15,7 +15,6 @@
 #include "cutlass/gemm/wmma_gemm_traits.h"
-// symbol to be automatically resolved by PyTorch libs
-extern THCState *state;
 rocblas_datatype a_type = rocblas_datatype_f16_r;
 rocblas_datatype b_type = rocblas_datatype_f16_r;
...
@@ -29,16 +28,19 @@ rocblas_int flags = 0;
 cublasOperation_t convertTransToCublasOperation(char trans) {
-  if (trans == 't') return CUBLAS_OP_T;
-  else if (trans == 'n') return CUBLAS_OP_N;
-  else if (trans == 'c') return CUBLAS_OP_C;
+  if (trans == 't')
+    return CUBLAS_OP_T;
+  else if (trans == 'n')
+    return CUBLAS_OP_N;
+  else if (trans == 'c')
+    return CUBLAS_OP_C;
   else {
     AT_ERROR("trans must be one of: t, n, c");
     return CUBLAS_OP_T;
   }
 }

-void RocblasStridedBatchedGemm(THCState *state, char transa, char transb,
-                               long m, long n, long k,
+void RocblasStridedBatchedGemm(char transa, char transb, long m, long n,
+                               long k,
                                float alpha, const half *a, long lda,
                                long strideA, const half *b, long ldb,
                                long strideB, float beta, half *c, long ldc,
                                long strideC, half *d, long ldd, long strideD,
                                long batchCount, rocblas_gemm_algo algo) {
   cublasOperation_t opa = convertTransToCublasOperation(transa);
...
@@ -59,151 +61,71 @@ void RocblasStridedBatchedGemm(THCState *state, char transa, char transb, long m
                                  (int)batchCount, compute_type, algo,
                                  solution_index, flags));
 }

-void gemm_switch_fp32accum(THCState *state, char transa, char transb,
-                           long m, long n, long k,
+void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
                            float alpha, const half *a, long lda, long strideA,
                            const half *b, long ldb, long strideB, float beta,
                            half *c, long ldc, long strideC, half *d, long ldd,
                            long strideD, long batchCount) {
   auto stream = c10::cuda::getCurrentCUDAStream();
   if ((transa == 't') && (transb == 'n')) {
     if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
-      RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a,
-                                lda, strideA, b, ldb, strideB, beta, c, ldc,
-                                strideC, d, ldd, strideD, batchCount, algo);
+      RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda,
+                                strideA, b, ldb, strideB, beta, c, ldc,
+                                strideC, d, ldd, strideD, batchCount, algo);
     } else {
-      RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a,
-                                lda, strideA, b, ldb, strideB, beta, c, ldc,
-                                strideC, d, ldd, strideD, batchCount, algo);
+      RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda,
+                                strideA, b, ldb, strideB, beta, c, ldc,
+                                strideC, d, ldd, strideD, batchCount, algo);
     }
   } else if ((transa == 'n') && (transb == 'n')) {
     if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
-      RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a,
-                                lda, strideA, b, ldb, strideB, beta, c, ldc,
-                                strideC, d, ldd, strideD, batchCount, algo);
+      RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda,
+                                strideA, b, ldb, strideB, beta, c, ldc,
+                                strideC, d, ldd, strideD, batchCount, algo);
     } else {
-      RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a,
-                                lda, strideA, b, ldb, strideB, beta, c, ldc,
-                                strideC, d, ldd, strideD, batchCount, algo);
+      RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda,
+                                strideA, b, ldb, strideB, beta, c, ldc,
+                                strideC, d, ldd, strideD, batchCount, algo);
     }
   } else if ((transa == 'n') && (transb == 't')) {
     if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
-      RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a,
-                                lda, strideA, b, ldb, strideB, beta, c, ldc,
-                                strideC, d, ldd, strideD, batchCount, algo);
+      RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda,
+                                strideA, b, ldb, strideB, beta, c, ldc,
+                                strideC, d, ldd, strideD, batchCount, algo);
     } else {
-      RocblasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a,
-                                lda, strideA, b, ldb, strideB, beta, c, ldc,
-                                strideC, d, ldd, strideD, batchCount, algo);
+      RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda,
+                                strideA, b, ldb, strideB, beta, c, ldc,
+                                strideC, d, ldd, strideD, batchCount, algo);
     }
   } else {
     AT_ASSERTM(false, "TransA and TransB are invalid");
   }
 }
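With `THCState` gone, the stream is no longer threaded through the API; `gemm_switch_fp32accum` now queries it directly, as the `c10::cuda::getCurrentCUDAStream()` line above shows. A minimal standalone sketch of that pattern (the helper name is hypothetical):

#include <c10/cuda/CUDAStream.h>

// Hypothetical helper: the current CUDA/HIP stream comes from c10 directly
// rather than from a THCState* handed in by the caller.
void enqueue_on_current_stream() {
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
  // ... launch kernels / rocBLAS work on `stream` ...
  (void)stream;
}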
-void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t *lda, int64_t *ldb, int64_t *ldc) {
+void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k,
+                    int64_t *lda, int64_t *ldb, int64_t *ldc) {
   int transa_ = ((transa == 't') || (transa == 'T'));
   int transb_ = ((transb == 't') || (transb == 'T'));

-  // Note: leading dimensions generally are checked that they are > 0 and at least as big the result
-  // requires (even if the value won't be used).
+  // Note: leading dimensions generally are checked that they are > 0 and at
+  // least as big the result requires (even if the value won't be used).
   if (n <= 1)
     *ldc = std::max<int64_t>(m, 1);

   if (transa_) {
     if (m <= 1)
       *lda = std::max<int64_t>(k, 1);
   } else {
     if (k <= 1)
       *lda = std::max<int64_t>(m, 1);
   }

   if (transb_) {
     if (k <= 1)
       *ldb = std::max<int64_t>(n, 1);
   } else {
     if (n <= 1)
       *ldb = std::max<int64_t>(k, 1);
   }
 }
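The clamping rule in `adjustLdLevel3` only kicks in for degenerate shapes. A minimal worked example, assuming this header is included (the values are illustrative):

#include <cstdint>

// For a 'n','n' GEMM with n == 1 the result has one column, so ldc is
// raised to max(m, 1) and ldb to max(k, 1); lda stays put because k > 1.
void adjust_example() {
  int64_t lda = 1, ldb = 1, ldc = 1;
  adjustLdLevel3('n', 'n', /*m=*/4, /*n=*/1, /*k=*/3, &lda, &ldb, &ldc);
  // now lda == 1, ldb == 3, ldc == 4
}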
-void HgemmStridedBatched(THCState *state, char transa, char transb, long m,
-                         long n, long k, float alpha,
+void HgemmStridedBatched(char transa, char transb, long m, long n, long k,
+                         float alpha,
                          const half *a, long lda, long strideA,
                          const half *b, long ldb, long strideB, float beta,
-                         half *c, long ldc, long strideC,
-                         half *d, long ldd, long strideD, long batchCount)
-{
-  if ((m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) ||
-      (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) ||
-      (batchCount >= INT_MAX)) {
-    AT_ERROR("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount"
-             "with the bound [val] <= %d", INT_MAX);
+                         half *c, long ldc, long strideC, half *d, long ldd,
+                         long strideD, long batchCount) {
+  if ((m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) ||
+      (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX)) {
+    AT_ERROR("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, "
+             "batchCount"
+             "with the bound [val] <= %d",
+             INT_MAX);
   }
   adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
-  //gemm_switch(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
-  gemm_switch_fp32accum(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount);
+  // gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA,
+  // b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+  gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA, b,
+                        ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD,
+                        batchCount);
 }

 /******
 at::Tensor strided_batched_gemm_cuda(
     float beta,
     at::Tensor in_result,
     float alpha,
     at::Tensor batch1,
     at::Tensor batch2) {
   bool transpose_result;
   char transpose_batch1, transpose_batch2;
   int64_t lda, ldb, ldc;
   at::Tensor result, input1, input2;
   if (in_result.stride(1) == 1) {
     transpose_result = false;
     result = in_result;
     ldc = result.stride(2);
   } else if (in_result.stride(2) == 1) {
     transpose_result = true;
     at::Tensor swap = batch2;
     batch2 = batch1;
     batch1 = swap;
     result = in_result;
     ldc = result.stride(1);
   } else {
     AT_ASSERTM(false, "result should be contiguous");
   }
   if (batch1.stride(transpose_result ? 2 : 1) == 1 &&
       batch1.stride(transpose_result ? 1 : 2) != 0) {
     transpose_batch1 = 'n';
     input1 = batch1;
     lda = input1.stride(transpose_result ? 1 : 2);
   } else if (batch1.stride(transpose_result ? 1 : 2) == 1 &&
              batch1.stride(transpose_result ? 2 : 1) != 0) {
     transpose_batch1 = 't';
     input1 = batch1;
     lda = input1.stride(transpose_result ? 2 : 1);
   } else {
     AT_ASSERTM(false, "input1 should be contiguous");
   }
   if (batch2.stride(transpose_result ? 2 : 1) == 1 &&
       batch2.stride(transpose_result ? 1 : 2) != 0) {
     transpose_batch2 = 'n';
     input2 = batch2;
     ldb = input2.stride(transpose_result ? 1 : 2);
   } else if (batch2.stride(transpose_result ? 1 : 2) == 1 &&
              batch2.stride(transpose_result ? 2 : 1) != 0) {
     transpose_batch2 = 't';
     input2 = batch2;
     ldb = input2.stride(transpose_result ? 2 : 1);
   } else {
     AT_ASSERTM(false, "input2 should be contiguous");
   }
   int64_t num_batches = result.size(0);
   HgemmStridedBatched(
       state,
       transpose_batch1,
       transpose_batch2,
       result.size(transpose_result ? 2 : 1),
       result.size(transpose_result ? 1 : 2),
       input1.size(transpose_result ? 1 : 2),
       alpha,
       static_cast<const half*>(input1.data_ptr()), lda, input1.stride(0),
       static_cast<const half*>(input2.data_ptr()), ldb, input2.stride(0),
       beta,
       static_cast<half*>(result.data_ptr()), ldc, result.stride(0),
       num_batches);
   return in_result;
 }
 ***/