OpenDAS / apex · Commit 7eed2594

Authored Sep 28, 2023 by flyingdown
Parent commit: 0816a70e

    Rewrite the forward-pass GEMM for sequence lengths 128 and 256 using mmac instructions
Showing 15 changed files with 894 additions and 141 deletions.
    +30   -12   apex/contrib/csrc/fmha/fmha_api.cpp
    +73   -87   apex/contrib/csrc/fmha/src/fmha/gemm.h
    +40    -0   apex/contrib/csrc/fmha/src/fmha/gmem_tile.h
    +21    -0   apex/contrib/csrc/fmha/src/fmha/mask.h
    +193  -11   apex/contrib/csrc/fmha/src/fmha/smem_tile.h
    +170   -1   apex/contrib/csrc/fmha/src/fmha/softmax.h
    +84    -7   apex/contrib/csrc/fmha/src/fmha/utils.h
    +4     -0   apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu
    +4     -0   apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu
    +4     -0   apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu
    +4     -0   apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu
    +4     -0   apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu
    +255  -17   apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h
    +2     -0   apex/contrib/csrc/fmha/src/fmha_kernel.h
    +6     -6   setup.py
File: apex/contrib/csrc/fmha/fmha_api.cpp (+30 -12)
@@ -80,7 +80,11 @@ void set_params(Fused_multihead_attention_fprop_params &params,
     params.p_dropout = 1.f - p_dropout;
     params.rp_dropout = 1.f / params.p_dropout;
     TORCH_CHECK(p_dropout < 1.f);
+#if defined (__HIP_PLATFORM_HCC__)
+    set_alpha(params.scale_dropout, params.rp_dropout, acc_type);
+#else
     set_alpha(params.scale_dropout, params.rp_dropout, data_type);
+#endif
 }

 std::vector<at::Tensor>
@@ -94,24 +98,38 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
                              c10::optional<at::Generator> gen_) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
+#if not defined(__HIP_PLATFORM_HCC__)
     TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
+#endif
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_training, is_nl);
-    int seq_len = 512;
-    auto launch = &run_fmha_fp16_512_64_sm80;
+    // int seq_len = 512;
+    // auto launch = &run_fmha_fp16_512_64_sm80;
+    // if( max_seq_len <= 128 ) {
+    //     seq_len = 128;
+    //     launch = &run_fmha_fp16_128_64_sm80;
+    // } else if( max_seq_len <= 256 ) {
+    //     seq_len = 256;
+    //     launch = &run_fmha_fp16_256_64_sm80;
+    // } else if( max_seq_len <= 384 ) {
+    //     seq_len = 384;
+    //     launch = &run_fmha_fp16_384_64_sm80;
+    // } else if( max_seq_len <= 512 ) {
+    //     seq_len = 512;
+    //     launch = &run_fmha_fp16_512_64_sm80;
+    // } else {
+    //     TORCH_CHECK(false);
+    // }
+    int seq_len = 256;
+    auto launch = &run_fmha_fp16_256_64_sm80;
     if( max_seq_len <= 128 ) {
         seq_len = 128;
         launch = &run_fmha_fp16_128_64_sm80;
     } else if( max_seq_len <= 256 ) {
         seq_len = 256;
         launch = &run_fmha_fp16_256_64_sm80;
-    } else if( max_seq_len <= 384 ) {
-        seq_len = 384;
-        launch = &run_fmha_fp16_384_64_sm80;
-    } else if( max_seq_len <= 512 ) {
-        seq_len = 512;
-        launch = &run_fmha_fp16_512_64_sm80;
     } else {
         TORCH_CHECK(false);
     }
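With the 384/512 branches dropped, the launcher now dispatches only the two rewritten buckets, and anything longer trips TORCH_CHECK(false). Restated as a minimal host-side helper (the function name is ours, not the project's):

    static inline int pick_seq_len_bucket(int max_seq_len) {
        if (max_seq_len <= 128) return 128;   // run_fmha_fp16_128_64_sm80
        if (max_seq_len <= 256) return 256;   // run_fmha_fp16_256_64_sm80
        return -1;                            // mha_fwd raises TORCH_CHECK(false)
    }

This matches the commit title: only the 128/256 forward kernels have been rewritten for mmac on this branch.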
@@ -178,7 +196,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
     return { ctx, s };
 }
-
+/*
 std::vector<at::Tensor>
 mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
         const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
@@ -351,11 +369,11 @@ std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num
         dqkv.data_ptr(), dkv.data_ptr(), cu_seqlens.data_ptr<int>(), hidden_size, batch_size, total, num_chunks, stream);
     return { dqkv, softmax, dkv };
 }
+*/
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.doc() = "Fused Multi-head Self-attention for BERT";
     m.def("fwd", &mha_fwd, "Forward pass");
-    m.def("bwd", &mha_bwd, "Backward pass");
-    m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)");
+    // m.def("bwd", &mha_bwd, "Backward pass");
+    // m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)");
 }
File: apex/contrib/csrc/fmha/src/fmha/gemm.h (+73 -87)
@@ -145,85 +145,57 @@ struct Fragment_b : public Fragment<uint16_t, 8> {
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////

 #if defined (__HIP_PLATFORM_HCC__)
-__device__ inline void f16mulf16addf32(uint32_t &a, uint32_t &b, const float *c, float *d){
-    // uint32_t res = 0;
-    // asm volatile("v_pk_fma_f16 %0, %1,%2,%3" : "=v"(res) : "v"(a), "v"(b), "v"(res));
-    // __half * h = reinterpret_cast<__half*>(&res);
-    __half * ha = reinterpret_cast<__half*>(&a);
-    __half * hb = reinterpret_cast<__half*>(&b);
-    float C = *c, D = *d;
-    *d = *c + __half2float(ha[0]) * __half2float(hb[0]) + __half2float(ha[1]) * __half2float(hb[1]);
-    // if (threadIdx.x == 15) {
-    //     printf("f16mulf16addf32 %i: A %f, %f, B %f, %f, RES %f, %f, %f, C %f, %f, D %f, %f \n", threadIdx.x,
-    //            __half2float(ha[0]), __half2float(ha[1]),
-    //            __half2float(hb[0]), __half2float(hb[1]),
-    //            __half2float(ha[0])*__half2float(hb[0]),
-    //            __half2float(ha[1])*__half2float(hb[1]),
-    //            __half2float(ha[0])*__half2float(hb[0]) + __half2float(ha[1])*__half2float(hb[1]),
-    //            C, *c, D, *d
-    //     );
-    // }
-}
-
-// row 8 col 4
-__device__ inline void m16n8k16(const uint32_t * A, const uint32_t * B, /*const float * C,*/ float * D) {
-    int tid = threadIdx.x;
-    int baseId = tid / 32 * 32;
-    __shared__ uint32_t smem[256 * 6];
-    int base = tid * 6;
-    __builtin_memcpy(smem + base,       A,     sizeof(uint32_t));
-    __builtin_memcpy(smem + (base + 1), A + 1, sizeof(uint32_t));
-    __builtin_memcpy(smem + (base + 2), A + 2, sizeof(uint32_t));
-    __builtin_memcpy(smem + (base + 3), A + 3, sizeof(uint32_t));
-    __builtin_memcpy(smem + (base + 4), B,     sizeof(uint32_t));
-    __builtin_memcpy(smem + (base + 5), B + 1, sizeof(uint32_t));
-    __syncthreads();
-    /* From D's point of view: each thread is responsible for computing its piece of D.
-       Starting from thread 0, fetch one row of A and two columns of B.
-       s is the thread index into B; baseA is the thread index into A;
-       baseB0 is the first B column fetched by this thread, baseB1 the second. */
-    int s = baseId + (tid % 4) * 8, e = s + 4;
-    for (int i = s; i < e; ++i) {
-        // A[0]->i A[1]->i+1 A[2]->i+2 A[3]->i+3 B[0]->i+4 B[1]->i+5
-        int baseA = (tid - tid % 4 + i - s) * 6;
-        // thread index of the first column of this tid's row, plus the stride, times 6
-        int baseB0 = i * 6, baseB1 = (i + 4) * 6;
-        f16mulf16addf32(smem[baseA],     smem[baseB0 + 4], D,     D);
-        f16mulf16addf32(smem[baseA + 2], smem[baseB0 + 5], D,     D);
-        f16mulf16addf32(smem[baseA],     smem[baseB1 + 4], D + 1, D + 1);
-        f16mulf16addf32(smem[baseA + 2], smem[baseB1 + 5], D + 1, D + 1);
-        f16mulf16addf32(smem[baseA + 1], smem[baseB0 + 4], D + 2, D + 2);
-        f16mulf16addf32(smem[baseA + 3], smem[baseB0 + 5], D + 2, D + 2);
-        f16mulf16addf32(smem[baseA + 1], smem[baseB1 + 4], D + 3, D + 3);
-        f16mulf16addf32(smem[baseA + 3], smem[baseB1 + 5], D + 3, D + 3);
-    }
-    // __half * a0 = reinterpret_cast<__half*>(smem+base);
-    // __half * a1 = reinterpret_cast<__half*>(smem+base+1);
-    // __half * a2 = reinterpret_cast<__half*>(smem+base+2);
-    // __half * a3 = reinterpret_cast<__half*>(smem+base+3);
-    // __half * b0 = reinterpret_cast<__half*>(smem+base+4);
-    // __half * b1 = reinterpret_cast<__half*>(smem+base+5);
-    // printf("m16n8k16 %i: \n A %f, %f, %f, %f, %f, %f, %f, %f \n B %f, %f, %f, %f \n D %f, %f, %f, %f \n", threadIdx.x,
-    //        __half2float(a0[0]), __half2float(a0[1]),
-    //        __half2float(a1[0]), __half2float(a1[1]),
-    //        __half2float(a2[0]), __half2float(a2[1]),
-    //        __half2float(a3[0]), __half2float(a3[1]),
-    //        __half2float(b0[0]), __half2float(b0[1]),
-    //        __half2float(b1[0]), __half2float(b1[1]),
-    //        D[0], D[1], D[2], D[3]
-    // );
-}
-#endif
+struct Fragment_accumulator : public Fragment<float, 4> {
+
+    // The base class.
+    using Base = Fragment<float, 8>;
+
+    // Add two fragments.
+    template< typename Other_fragment_ >
+    inline __device__ void add(const Other_fragment_ &other) {
+        for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
+            this->elt(ii) = this->elt(ii) + other.elt(ii);
+        }
+    }
+
+    // Do the HMMA.
+    template< typename Layout_a, typename Layout_b >
+    inline __device__ void mma(const Fragment_a<Layout_a> &a,
+                               const Fragment_b<Layout_b> &b) {
+        // const uint32_t * A = reinterpret_cast<const uint32_t*>(a.regs_);
+        // const uint32_t * B = reinterpret_cast<const uint32_t*>(b.regs_);
+        // float * D = reinterpret_cast<float*>(regs_);
+        // float regs[8];
+        // __builtin_memcpy(regs, D, sizeof(float)*8);
+        // m16n8k16(A, B, D);
+        // m16n8k16(A, B+2, D+4);
+        using v4f = __attribute__((__vector_size__(4 * sizeof(float)))) float;
+        v4f * rC = reinterpret_cast<v4f*>(regs_);
+        // float rA = reinterpret_cast<const float&>(a.reg(0));
+        // float rB = reinterpret_cast<const float&>(b.reg(0));
+        float rA0 = a.template elt_as<float>(0);
+        float rB0 = b.template elt_as<float>(0);
+        *rC = __builtin_amdgcn_mmac_f32_16x16x4f32(rA0, rB0, *rC, 0);
+        float rA1 = a.template elt_as<float>(1);
+        float rB1 = b.template elt_as<float>(1);
+        *rC = __builtin_amdgcn_mmac_f32_16x16x4f32(rA1, rB1, *rC, 0);
+        float rA2 = a.template elt_as<float>(2);
+        float rB2 = b.template elt_as<float>(2);
+        *rC = __builtin_amdgcn_mmac_f32_16x16x4f32(rA2, rB2, *rC, 0);
+        float rA3 = a.template elt_as<float>(3);
+        float rB3 = b.template elt_as<float>(3);
+        *rC = __builtin_amdgcn_mmac_f32_16x16x4f32(rA3, rB3, *rC, 0);
+        // if (blockIdx.x == 0) {
+        //     printf("tid:%d rA0:%6.4f rB0:%6.4f rA1:%6.4f rB1:%6.4f rA2:%6.4f rB2:%6.4f rA3:%6.4f rB3:%6.4f c0:%6.4f c1:%6.4f c2:%6.4f c3:%6.4f\n", threadIdx.x,
+        //            rA0, rB0, rA1, rB1, rA2, rB2, rA3, rB3, elt(0), elt(1), elt(2), elt(3));
+        // }
+    }
+};
+#else

 ////////////////////////////////////////////////////////////////////////////////////////////////////

 struct Fragment_accumulator : public Fragment<float, 8> {
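This hunk is the heart of the commit: the old HIP path emulated m16n8k16 element-by-element through shared memory, while the new accumulator drives __builtin_amdgcn_mmac_f32_16x16x4f32, a 16x16x4 matrix-multiply-accumulate executed by one 64-lane wavefront; each lane contributes one A and one B element per step and holds four floats of the 16x16 result tile. A sketch of the chaining pattern in isolation (assuming a toolchain that exposes this builtin, as the diff does; the wrapper name is ours):

    #include <hip/hip_runtime.h>

    using v4f = __attribute__((__vector_size__(4 * sizeof(float)))) float;

    // Four chained k=4 steps realize one logical 16x16x16 warp-level GEMM,
    // mirroring Fragment_accumulator::mma() in the hunk above.
    __device__ void mma_16x16x16(const float (&a)[4], const float (&b)[4], v4f &acc) {
        #pragma unroll
        for (int k = 0; k < 4; ++k) {
            acc = __builtin_amdgcn_mmac_f32_16x16x4f32(a[k], b[k], acc, 0);
        }
    }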
@@ -242,15 +214,6 @@ struct Fragment_accumulator : public Fragment<float, 8> {
     template<typename Layout_a, typename Layout_b>
     inline __device__ void mma(const Fragment_a<Layout_a> &a,
                                const Fragment_b<Layout_b> &b) {
-#if defined (__HIP_PLATFORM_HCC__)
-        const uint32_t * A = reinterpret_cast<const uint32_t*>(a.regs_);
-        const uint32_t * B = reinterpret_cast<const uint32_t*>(b.regs_);
-        float * D = reinterpret_cast<float*>(regs_);
-        float regs[8];
-        __builtin_memcpy(regs, D, sizeof(float)*8);
-        m16n8k16(A, B, D);
-        m16n8k16(A, B+2, D+4);
-#else
         asm volatile( \
             "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
             " {%0, %1, %2, %3}, \n" \

@@ -269,11 +232,10 @@ struct Fragment_accumulator : public Fragment<float, 8> {
             : "+f"(elt(4)), "+f"(elt(5)), "+f"(elt(6)), "+f"(elt(7))
             : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
             , "r"(b.reg(2)), "r"(b.reg(3)));
-#endif
     }
 };
+#endif
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template<typename Fragment, int M, int N>

@@ -310,8 +272,24 @@ inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N])
     #pragma unroll
     for( int mi = 0; mi < M; ++mi ) {
+        // wangaq debug
+        // if (blockIdx.x == 0) {
+        //     printf("a tid:%d mi:%d %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, mi,
+        //            a[mi].template elt_as<float>(0),
+        //            a[mi].template elt_as<float>(1),
+        //            a[mi].template elt_as<float>(2),
+        //            a[mi].template elt_as<float>(3));
+        // }
         #pragma unroll
         for( int ni = 0; ni < N; ++ni ) {
+            // wangaq debug
+            // if (blockIdx.x == 0) {
+            //     printf("b tid:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, ni,
+            //            b[ni].template elt_as<float>(0),
+            //            b[ni].template elt_as<float>(1),
+            //            b[ni].template elt_as<float>(2),
+            //            b[ni].template elt_as<float>(3));
+            // }
             acc[mi][ni].mma(a[mi], b[ni]);
         }
     }
@@ -340,7 +318,11 @@ struct Cta_tile_ {
     // The number of warps per CTA.
     enum { WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K };
     // The number of threads per warp.
+#if defined(__HIP_PLATFORM_HCC__)
+    enum { THREADS_PER_WARP = 64 };
+#else
     enum { THREADS_PER_WARP = 32 };
+#endif
     // The number of threads per CTA.
     enum { THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP };
 };
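The only change here is making the warp width platform-dependent: AMD wavefronts are 64 lanes wide, so the same thread budget yields half as many warps per CTA, which is why several static_asserts later in this commit start accepting halved warp counts. A minimal compile-time check, with illustrative warp counts (not a specific kernel's configuration):

    constexpr int WARPS_PER_CTA = 1 * 4 * 1;   // WARPS_M * WARPS_N * WARPS_K (illustrative)
    #if defined(__HIP_PLATFORM_HCC__)
    constexpr int THREADS_PER_WARP = 64;       // one AMD wavefront
    #else
    constexpr int THREADS_PER_WARP = 32;       // one NVIDIA warp
    #endif
    constexpr int THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP;  // 256 vs. 128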
@@ -350,7 +332,11 @@ struct Cta_tile_ {
 template<typename Cta_tile>
 struct Hmma_tile {
     // The number of elements computed with a single warp-MMA.
+    // #if defined(__HIP_PLATFORM_HCC__)
+    // enum { M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 4 };
+    // #else
     enum { M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16 };
+    // #endif
     // The number of elements computed with a single CTA-MMA.
     enum {
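The commented-out K_PER_MMA = 4 hints at the design choice: one mmac instruction only covers a k-depth of 4, but the tile bookkeeping keeps K_PER_MMA = 16 and the fragment loaders and mma() chain four k=4 instructions per logical step (see the 0..3 * k_offset loads in smem_tile.h below). A one-line sanity check of that arithmetic:

    constexpr int K_PER_MMAC_INSTR = 4;   // hardware k-depth of one mmac
    constexpr int K_PER_MMA = 16;         // logical step kept by Hmma_tile
    static_assert(K_PER_MMA / K_PER_MMAC_INSTR == 4, "four chained mmac ops per logical MMA");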
File: apex/contrib/csrc/fmha/src/fmha/gmem_tile.h (+40 -0)
@@ -85,6 +85,20 @@ struct Gmem_tile_qkv {
     // Store data to shared memory.
     template< typename Smem_tile >
     inline __device__ void commit(Smem_tile &smem_tile) {
+        // wangaq debug
+        // for( int ii = 0; ii < LDGS; ++ii ) {
+        //     if (blockIdx.x == 0) {
+        //         printf("commit tid:%d LDGS:%d ii:%d %f %f %f %f %f %f %f %f\n", threadIdx.x, LDGS, ii,
+        //                __half2float(reinterpret_cast<__half*>(&fetch_[ii])[0]),
+        //                __half2float(reinterpret_cast<__half*>(&fetch_[ii])[1]),
+        //                __half2float(reinterpret_cast<__half*>(&fetch_[ii])[2]),
+        //                __half2float(reinterpret_cast<__half*>(&fetch_[ii])[3]),
+        //                __half2float(reinterpret_cast<__half*>(&fetch_[ii])[4]),
+        //                __half2float(reinterpret_cast<__half*>(&fetch_[ii])[5]),
+        //                __half2float(reinterpret_cast<__half*>(&fetch_[ii])[6]),
+        //                __half2float(reinterpret_cast<__half*>(&fetch_[ii])[7]));
+        //     }
+        // }
         smem_tile.store(fetch_);
     }
@@ -105,6 +119,18 @@ struct Gmem_tile_qkv {
         #pragma unroll
         for( int ii = 0; ii < LDGS; ++ii ) {
             fct.load(ii, preds[ii]);
+            // wangaq debug
+            // if (blockIdx.x == 0) {
+            //     printf("load tid:%d LDGS:%d ii:%d %f %f %f %f %f %f %f %f\n", threadIdx.x, LDGS, ii,
+            //            __half2float(reinterpret_cast<__half*>(&fetch_[ii])[0]),
+            //            __half2float(reinterpret_cast<__half*>(&fetch_[ii])[1]),
+            //            __half2float(reinterpret_cast<__half*>(&fetch_[ii])[2]),
+            //            __half2float(reinterpret_cast<__half*>(&fetch_[ii])[3]),
+            //            __half2float(reinterpret_cast<__half*>(&fetch_[ii])[4]),
+            //            __half2float(reinterpret_cast<__half*>(&fetch_[ii])[5]),
+            //            __half2float(reinterpret_cast<__half*>(&fetch_[ii])[6]),
+            //            __half2float(reinterpret_cast<__half*>(&fetch_[ii])[7]));
+            // }
         }
     }
@@ -254,8 +280,13 @@ struct Gmem_tile_mma_sd {
...
@@ -254,8 +280,13 @@ struct Gmem_tile_mma_sd {
// The mma tile.
// The mma tile.
using
Mma_tile
=
fmha
::
Hmma_tile
<
Cta_tile
>
;
using
Mma_tile
=
fmha
::
Hmma_tile
<
Cta_tile
>
;
#if defined(__HIP_PLATFORM_HCC__)
// Each STG stores 16 elements.
enum
{
BYTES_PER_STG
=
BYTES_PER_ELEMENT
*
4
};
#else
// Each STG stores 8 elements.
// Each STG stores 8 elements.
enum
{
BYTES_PER_STG
=
BYTES_PER_ELEMENT
*
8
};
enum
{
BYTES_PER_STG
=
BYTES_PER_ELEMENT
*
8
};
#endif
// The number of MMAs in the M dimension.
// The number of MMAs in the M dimension.
enum
{
MMAS_M
=
Mma_tile
::
MMAS_M
};
enum
{
MMAS_M
=
Mma_tile
::
MMAS_M
};
// The number of MMAs in the N dimension.
// The number of MMAs in the N dimension.
...
@@ -369,6 +400,14 @@ struct Gmem_tile_mma_s : public Base {
         for( int mi = 0; mi < M; mi++ ) {
             #pragma unroll
             for( int ni = 0; ni < N; ni++ ) {
+#if defined(__HIP_PLATFORM_HCC__)
+                uint2 dst;
+                dst.x = float2_to_half2(frag[ni][mi].reg(0), frag[ni][mi].reg(1));
+                dst.y = float2_to_half2(frag[ni][mi].reg(2), frag[ni][mi].reg(3));
+                if( mask.any_valid(mi, ni) ) {
+                    Base::store(dst, mi, ni);
+                }
+#else
                 uint4 dst;
                 dst.x = frag[ni][mi].reg(0);
                 dst.y = frag[ni][mi].reg(2);

@@ -377,6 +416,7 @@ struct Gmem_tile_mma_s : public Base {
                 if( mask.any_valid(mi, ni) ) {
                     Base::store(dst, mi, ni);
                 }
+#endif
             }
         }
     }
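On the HIP side each thread stores a uint2 (four fp16 values) instead of a uint4, because the mmac accumulator yields four fp32 values per lane that must be packed with float2_to_half2 first. A sketch of the equivalent packing (float2_to_half2 itself lives in utils.h, whose diff is not shown here, so this is our restatement of the intent):

    #include <hip/hip_fp16.h>

    // Two fp32 values become one packed half2 register; two such registers
    // fill the uint2 that one thread stores per (mi, ni).
    __device__ uint32_t pack_half2(float lo, float hi) {
        __half2 h = __floats2half2_rn(lo, hi);        // round both floats to fp16
        return *reinterpret_cast<uint32_t *>(&h);     // reinterpret as one 32-bit register
    }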
File: apex/contrib/csrc/fmha/src/fmha/mask.h (+21 -0)
@@ -47,11 +47,28 @@ struct Mask {
         // find the warp in the Cta tile
         const int warp_n = (warp / Cta_tile::WARPS_M);
         const int warp_m = (warp % Cta_tile::WARPS_M);
+#if defined(__HIP_PLATFORM_HCC__)
+        // decompose warp into 16x16 tile
+        const int quad = lane % 16;
+        const int tid = lane / 16;
+        row = warp_m * 16 + quad;
+        col = warp_n * 16 + tid;
+#else
         // decompose warp into 8x4 tile
         const int quad = lane / 4;
         const int tid = (lane % 4) * 2;
         row = warp_m * 16 + quad;
         col = warp_n * 16 + tid;
+#endif
     }

+    inline __device__ bool is_valid(const int mi, const int ni, const int jj) const {
+        // jj iterate over the 1x4 fragment
+        const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + 4 * jj) < actual_seqlen;
+        //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen;
+        return col_valid;
+        // return row_valid && col_valid;
+    }
+
     inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const {
@@ -65,7 +82,11 @@ struct Mask {
...
@@ -65,7 +82,11 @@ struct Mask {
//BERT Mask: if upper left is invalid, none are valid
//BERT Mask: if upper left is invalid, none are valid
inline
__device__
bool
any_valid
(
int
mi
,
int
ni
)
const
{
inline
__device__
bool
any_valid
(
int
mi
,
int
ni
)
const
{
#if defined(__HIP_PLATFORM_HCC__)
return
is_valid
(
mi
,
ni
,
0
);
#else
return
is_valid
(
mi
,
ni
,
0
,
0
);
return
is_valid
(
mi
,
ni
,
0
,
0
);
#endif
}
}
inline
__device__
void
load
(
int
it
)
{
inline
__device__
void
load
(
int
it
)
{
...
...
File: apex/contrib/csrc/fmha/src/fmha/smem_tile.h (+193 -11)
@@ -69,7 +69,7 @@ struct Smem_tile_without_skews {
     // The number of bytes per row without packing of rows.
     enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 };
     // The number of bytes per row -- we want at least 128B per row.
-    enum { BYTES_PER_ROW = Max<BYTES_PER_ROW_BEFORE_PACKING, 128>::VALUE };
+    enum { BYTES_PER_ROW = Max<BYTES_PER_ROW_BEFORE_PACKING, 128>::VALUE + 4 };
     // The number of rows in shared memory (two rows may be packed into a single one).
     enum { ROWS = M_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW };
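The +4 widens every shared-memory row past the 128-byte boundary; with the plain row/column indexing the HIP paths use below (instead of the CUDA XOR swizzle), consecutive rows then start at different bank offsets, which is presumably the point of the padding. The arithmetic for an illustrative 64-column fp16 tile:

    constexpr int BYTES_PER_ROW_BEFORE_PACKING = 64 * 16 / 8;   // 128 bytes of fp16
    constexpr int BYTES_PER_ROW = 128 + 4;                      // padded to 132
    static_assert(BYTES_PER_ROW % 128 != 0, "rows deliberately stagger across banks");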
@@ -117,6 +117,18 @@ struct Smem_tile_without_skews {
     inline __device__ Smem_tile_without_skews(void *smem, int tidx)
         : smem_(__nvvm_get_smem_pointer(smem)) {

+#if defined (__HIP_PLATFORM_HCC__)
+        int smem_write_row = tidx / THREADS_PER_ROW;
+        int smem_write_col = tidx % THREADS_PER_ROW;
+        // The offset.
+        this->smem_write_offset_ = smem_write_row*BYTES_PER_ROW + smem_write_col*BYTES_PER_STS;
+        // TODO: Why not merge it with the read offset?
+        this->smem_read_buffer_ = __shfl(0, 0);
+        this->smem_write_buffer_ = __shfl(0, 0);
+#else
         // The row written by a thread. See doc/mma_smem_layout.xlsx.
         int smem_write_row = tidx / THREADS_PER_ROW;
@@ -129,10 +141,6 @@ struct Smem_tile_without_skews {
         this->smem_write_offset_ = smem_write_row*BYTES_PER_ROW + smem_write_col*BYTES_PER_STS;

         // TODO: Why not merge it with the read offset?
-#if defined (__HIP_PLATFORM_HCC__)
-        this->smem_read_buffer_ = __shfl(0, 0);
-        this->smem_write_buffer_ = __shfl(0, 0);
-#else
         this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0);
         this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0);
 #endif
@@ -259,6 +267,32 @@ struct Smem_tile_without_skews {
         uint32_t smem_ptrs[N];
         this->compute_store_pointers(smem_ptrs);
         sts(smem_ptrs, data);
+        // wangaq debug
+        // if (blockIdx.x == 0) {
+        //     extern __shared__ char smem[];
+        //     uint32_t base = __nvvm_get_smem_pointer(smem);
+        //     for (int ii = 0; ii < N; ++ii) {
+        //         printf("data tid:%d N:%d ii:%d %f %f %f %f %f %f %f %f\n", threadIdx.x, N, ii,
+        //                __half2float(reinterpret_cast<const __half*>(&data[ii])[0]),
+        //                __half2float(reinterpret_cast<const __half*>(&data[ii])[1]),
+        //                __half2float(reinterpret_cast<const __half*>(&data[ii])[2]),
+        //                __half2float(reinterpret_cast<const __half*>(&data[ii])[3]),
+        //                __half2float(reinterpret_cast<const __half*>(&data[ii])[4]),
+        //                __half2float(reinterpret_cast<const __half*>(&data[ii])[5]),
+        //                __half2float(reinterpret_cast<const __half*>(&data[ii])[6]),
+        //                __half2float(reinterpret_cast<const __half*>(&data[ii])[7]));
+        //         __half * smem_ptr = reinterpret_cast<__half*>(smem-base+smem_ptrs[ii]);
+        //         printf("smem_ptrs tid:%d N:%d ii:%d %f %f %f %f %f %f %f %f\n", threadIdx.x, N, ii,
+        //                __half2float(smem_ptr[0]),
+        //                __half2float(smem_ptr[1]),
+        //                __half2float(smem_ptr[2]),
+        //                __half2float(smem_ptr[3]),
+        //                __half2float(smem_ptr[4]),
+        //                __half2float(smem_ptr[5]),
+        //                __half2float(smem_ptr[6]),
+        //                __half2float(smem_ptr[7]));
+        //     }
+        // }
     }

     // Store to the tile in shared memory.
@@ -408,17 +442,28 @@ struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,
         const int WARPS_K = Cta_tile::WARPS_K;
         static_assert(WARPS_M == 1);
+#if defined (__HIP_PLATFORM_HCC__)
+        static_assert(WARPS_N == 2 || WARPS_N == 4);
+#else
         static_assert(WARPS_N == 4 || WARPS_N == 8);
+#endif
         static_assert(WARPS_K == 1);
         static_assert(Base::ROWS_PER_XOR_PATTERN == 8);

         // The row and column read by the thread.
+#if defined(__HIP_PLATFORM_HCC__)
+        const int M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA;
+        int smem_read_row = (tidx % M_PER_MMA_PER_CTA);
+        int smem_read_col = ((tidx & 0x3f) / M_PER_MMA_PER_CTA);
+        this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*(Base::BITS_PER_ELEMENT/8);
+#else
         int smem_read_row = (tidx & 0x0f);
         int smem_read_col = (tidx & 0x07);
         smem_read_col ^= (tidx & 0x10) / 16;
         // The shared memory offset.
         this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS;
+#endif
     }

     // Rewind smem_read_offset for last LDS phase in main loop.
@@ -437,6 +482,21 @@ struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,
         // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
         int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;

+#if defined(__HIP_PLATFORM_HCC__)
+        int k_offset = 4 /* the instruction's k-depth is 4 */ * (Base::BITS_PER_ELEMENT/8);
+        int ki_offset = ki * Mma_tile::K_PER_MMA * (Base::BITS_PER_ELEMENT/8);
+        ldsm(a[mi].reg(0), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 0 * k_offset + ki_offset);
+        ldsm(a[mi].reg(1), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 1 * k_offset + ki_offset);
+        ldsm(a[mi].reg(2), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 2 * k_offset + ki_offset);
+        ldsm(a[mi].reg(3), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 3 * k_offset + ki_offset);
+        // if (blockIdx.x == 0) {
+        //     printf("smem a load tid:%d %f %f %f %f\n", threadIdx.x,
+        //            a[mi].template elt_as<float>(0),
+        //            a[mi].template elt_as<float>(1),
+        //            a[mi].template elt_as<float>(2),
+        //            a[mi].template elt_as<float>(3));
+        // }
+#else
         // Load using LDSM.M88.4.
         uint4 tmp;
         ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
@@ -446,8 +506,12 @@ struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,
...
@@ -446,8 +506,12 @@ struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,
a
[
mi
].
reg
(
1
)
=
tmp
.
y
;
a
[
mi
].
reg
(
1
)
=
tmp
.
y
;
a
[
mi
].
reg
(
2
)
=
tmp
.
z
;
a
[
mi
].
reg
(
2
)
=
tmp
.
z
;
a
[
mi
].
reg
(
3
)
=
tmp
.
w
;
a
[
mi
].
reg
(
3
)
=
tmp
.
w
;
#endif
}
}
#if defined(__HIP_PLATFORM_HCC__)
// this->smem_read_offset_ = (ki+1) * Mma_tile::K_PER_MMA * (Base::BITS_PER_ELEMENT/8);
#else
// Move the offset to the next possition. See doc/mma_smem_layout.xlsx.
// Move the offset to the next possition. See doc/mma_smem_layout.xlsx.
static_assert
(
Mma_tile_with_padding
::
MMAS_K
<
64
,
"Not implemented"
);
static_assert
(
Mma_tile_with_padding
::
MMAS_K
<
64
,
"Not implemented"
);
if
(
Mma_tile_with_padding
::
MMAS_K
>=
32
&&
ki
%
16
==
15
)
{
if
(
Mma_tile_with_padding
::
MMAS_K
>=
32
&&
ki
%
16
==
15
)
{
...
@@ -461,6 +525,7 @@ struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,
...
@@ -461,6 +525,7 @@ struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,
}
else
if
(
Mma_tile_with_padding
::
MMAS_K
>=
2
)
{
}
else
if
(
Mma_tile_with_padding
::
MMAS_K
>=
2
)
{
this
->
smem_read_offset_
^=
1
*
BYTES_PER_LDS
*
2
;
this
->
smem_read_offset_
^=
1
*
BYTES_PER_LDS
*
2
;
}
}
#endif
}
}
// Reset the read offset.
// Reset the read offset.
...
@@ -593,9 +658,15 @@ struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
         const int WARPS_K = Cta_tile::WARPS_K;
         static_assert(Base::ROWS_PER_XOR_PATTERN == 8);
         static_assert(WARPS_M == 1);
-        static_assert(WARPS_N == 4 || WARPS_N == 8);
+        static_assert(WARPS_N == 2 || WARPS_N == 4 || WARPS_N == 8);
         static_assert(WARPS_K == 1);

+#if defined(__HIP_PLATFORM_HCC__)
+        const int N_PER_MMA = Mma_tile::N_PER_MMA;
+        int smem_read_row = (tidx % N_PER_MMA) + tidx / Cta_tile::THREADS_PER_WARP * N_PER_MMA;
+        int smem_read_col = (tidx / N_PER_MMA) % 4;  // the instruction's k-depth is 4
+        this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*(Base::BITS_PER_ELEMENT/8);
+#else
         // The masks to select the warps.
         const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;

@@ -610,6 +681,7 @@ struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
         smem_read_col ^= (tidx & 0x08) / 8;
         // The shared memory offset.
         this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS;
+#endif
     }

     // Rewind smem_read_offset for last LDS phase in main loop.
@@ -628,6 +700,21 @@ struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
         // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
         int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;

+#if defined(__HIP_PLATFORM_HCC__)
+        int k_offset = 4 /* the instruction's k-depth is 4 */ * (Base::BITS_PER_ELEMENT/8);
+        int ki_offset = ki * Mma_tile::K_PER_MMA_PER_CTA * (Base::BITS_PER_ELEMENT/8);
+        ldsm(b[ni].reg(0), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 0 * k_offset + ki_offset);
+        ldsm(b[ni].reg(1), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 1 * k_offset + ki_offset);
+        ldsm(b[ni].reg(2), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 2 * k_offset + ki_offset);
+        ldsm(b[ni].reg(3), this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset + 3 * k_offset + ki_offset);
+        // if (blockIdx.x == 0) {
+        //     printf("smem b load tid:%d %f %f %f %f\n", threadIdx.x,
+        //            b[ni].template elt_as<float>(0),
+        //            b[ni].template elt_as<float>(1),
+        //            b[ni].template elt_as<float>(2),
+        //            b[ni].template elt_as<float>(3));
+        // }
+#else
         // Load using LDSM.M88.4.
         uint4 tmp;
         ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
@@ -637,8 +724,12 @@ struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
         b[ni].reg(1) = tmp.y;
         b[ni].reg(2) = tmp.z;
         b[ni].reg(3) = tmp.w;
+#endif
     }

+#if defined(__HIP_PLATFORM_HCC__)
+    // this->smem_read_offset_ = (ki+1) * Mma_tile::K_PER_MMA_PER_CTA * (Base::BITS_PER_ELEMENT/8);
+#else
     // Move the offset to the next possition. See doc/mma_smem_layout.xlsx.
     static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
     if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {

@@ -652,6 +743,7 @@ struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
     } else if( Mma_tile_with_padding::MMAS_K >= 2 ) {
         this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;
     }
+#endif
 }

 // Reset the read offset.
@@ -919,20 +1011,41 @@ struct Smem_tile_v : public fmha::Smem_tile_without_skews<Cta_tile, Cta_tile::K,
         // The row/col read by the thread.
         int read_row, read_col;

-        static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
+        static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 2 || Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));

+#if defined(__HIP_PLATFORM_HCC__)
+        const int K_PER_MMA = Mma_tile::K_PER_MMA;
+        read_row = (tidx / 16) % 4 + (tidx / Cta_tile::THREADS_PER_WARP) * K_PER_MMA;
+        read_col = tidx % 16;
+        // The shared memory offset.
+        this->smem_read_offset_ = read_row*Base::BYTES_PER_ROW + read_col*(Base::BITS_PER_ELEMENT/8);
+#else
         read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f);
         read_col = (tidx & 0x07);
         read_col ^= (tidx & 0x10) / 16;
         // The shared memory offset.
         this->smem_read_offset_ = read_row*Base::BYTES_PER_ROW + read_col*BYTES_PER_LDS;
+#endif
     }

     // Load from shared memory.
     inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
         #pragma unroll
         for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
+#if defined(__HIP_PLATFORM_HCC__)
+            // Jump by 16 * #warps row.
+            // int row = ki * Cta_tile::K_PER_MMA_PER_CTA;
+            int col_offset = ni * Mma_tile::N_PER_MMA * (Base::BITS_PER_ELEMENT/8);
+            int k_offset = 4 /* the instruction's k-depth is 4 */ * Base::BYTES_PER_ROW;
+            int ki_offset = ki * Mma_tile::K_PER_MMA_PER_CTA * Base::BYTES_PER_ROW;
+            ldsm(b[ni].reg(0), this->smem_ + this->smem_read_offset_ + col_offset + 0 * k_offset + ki_offset);
+            ldsm(b[ni].reg(1), this->smem_ + this->smem_read_offset_ + col_offset + 1 * k_offset + ki_offset);
+            ldsm(b[ni].reg(2), this->smem_ + this->smem_read_offset_ + col_offset + 2 * k_offset + ki_offset);
+            ldsm(b[ni].reg(3), this->smem_ + this->smem_read_offset_ + col_offset + 3 * k_offset + ki_offset);
+#else
             // Jump by 16 * #warps row.
             int row = ki * 16 * Cta_tile::WARPS_K;

@@ -950,6 +1063,7 @@ struct Smem_tile_v : public fmha::Smem_tile_without_skews<Cta_tile, Cta_tile::K,
             } else {
                 assert(false); // Not implemented!
             }
+#endif
         }
     }
 };
@@ -1010,8 +1124,25 @@ struct Smem_tile_o {
         // Get a 32-bit value for the shared memory address.
         uint32_t smem_ = __nvvm_get_smem_pointer(smem);

-        static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
+        static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 2 || Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));

+#if defined(__HIP_PLATFORM_HCC__)
+        int write_row = tidx % 16;
+        int write_col = (tidx / 16) % 4 + (tidx / 64) * Mma_tile::K_PER_MMA * Mma_tile::MMAS_N;
+        // Assemble the write pointer.
+        smem_write_ = smem_ + write_row*BYTES_PER_ROW + write_col*BYTES_PER_ELEMENT;
+        // The element read by each thread.
+        int read_row = tidx / THREADS_PER_ROW;
+        int read_col = tidx % THREADS_PER_ROW;
+        // Take the XOR pattern into account for the column.
+        // read_col ^= 2 * (read_row & 0x7);
+        // Assemble the read pointer.
+        this->smem_read_ = smem_ + read_row*BYTES_PER_ROW + read_col*BYTES_PER_LDS;
+#else
         int write_row = (tidx & 0x1c) / 4;
         int write_col = (tidx);

@@ -1027,6 +1158,7 @@ struct Smem_tile_o {
         // Assemble the read pointer.
         this->smem_read_ = smem_ + read_row*BYTES_PER_ROW + read_col*BYTES_PER_LDS;
+#endif
         // Is that thread active on the last LDS?
         if( HAS_INCOMPLETE_LDS ) {
@@ -1036,6 +1168,34 @@ struct Smem_tile_o {
     // Load the output fragments.
     inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const {
+#if defined(__HIP_PLATFORM_HCC__)
+        for( int ii = 0; ii < LDS_PER_LOOP; ++ii ) {
+            // Load the elements before the reduction (split-K).
+            uint4 tmp[Cta_tile::WARPS_K];
+            #pragma unroll
+            for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) {
+                int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT;
+                if( !HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_) ) {
+                    fmha::lds(tmp[jj], this->smem_read_ + imm);
+                    // wangaq debug
+                    // float * xxx = reinterpret_cast<float*>(&tmp[0]);
+                    // xxx[0] = threadIdx.x * 10.0 + 0;
+                    // xxx[1] = threadIdx.x * 10.0 + 1;
+                    // xxx[2] = threadIdx.x * 10.0 + 2;
+                    // xxx[3] = threadIdx.x * 10.0 + 3;
+                    // fmha::sts(this->smem_read_ + imm, tmp[0]);
+                }
+            }
+            // Perform the reduction.
+            out[ii] = tmp[0];
+            #pragma unroll
+            for( int jj = 1; jj < Cta_tile::WARPS_K; ++jj ) {
+                out[ii] = fmha::fadd4(out[ii], tmp[jj]);
+            }
+        }
+#else
         #pragma unroll
         for( int ii = 0; ii < LDS_PER_LOOP; ++ii ) {
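For reference, the split-K reduction above relies on fmha::fadd4 adding four packed fp32 partial sums component-wise. A minimal restatement of those semantics (not the library's exact implementation, which lives in utils.h):

    __device__ uint4 fadd4_ref(uint4 x, uint4 y) {
        float4 a = *reinterpret_cast<float4 *>(&x);   // each uint4 carries four packed floats
        float4 b = *reinterpret_cast<float4 *>(&y);
        float4 c = make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
        return *reinterpret_cast<uint4 *>(&c);
    }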
@@ -1056,10 +1216,31 @@ struct Smem_tile_o {
                 out[ii] = fmha::fadd4(out[ii], tmp[jj]);
             }
         }
+#endif
     }

     // Store the accumulators.
     template<int M, int N>
     inline __device__ void store(const Accumulator (&acc)[M][N], int mi) {
+#if defined(__HIP_PLATFORM_HCC__)
+        for( int mi = 0; mi < M; ++mi ) {
+            for( int ni = 0; ni < N; ++ni ) {
+                int ni_offset = Mma_tile::K_PER_MMA * ni * BYTES_PER_ELEMENT;
+                // uint32_t tmp[4];
+                // reinterpret_cast<float&>(tmp[0]) = threadIdx.x * 100.0 + ni * 10 + 0;
+                // reinterpret_cast<float&>(tmp[1]) = threadIdx.x * 100.0 + ni * 10 + 1;
+                // reinterpret_cast<float&>(tmp[2]) = threadIdx.x * 100.0 + ni * 10 + 2;
+                // reinterpret_cast<float&>(tmp[3]) = threadIdx.x * 100.0 + ni * 10 + 3;
+                // fmha::sts(this->smem_write_ + ni_offset + 0 * BYTES_PER_ELEMENT, tmp[0]);
+                // fmha::sts(this->smem_write_ + ni_offset + 4 * BYTES_PER_ELEMENT, tmp[1]);
+                // fmha::sts(this->smem_write_ + ni_offset + 8 * BYTES_PER_ELEMENT, tmp[2]);
+                // fmha::sts(this->smem_write_ + ni_offset + 12 * BYTES_PER_ELEMENT, tmp[3]);
+                fmha::sts(this->smem_write_ + ni_offset +  0 * BYTES_PER_ELEMENT, acc[mi][ni].reg(0));
+                fmha::sts(this->smem_write_ + ni_offset +  4 * BYTES_PER_ELEMENT, acc[mi][ni].reg(1));
+                fmha::sts(this->smem_write_ + ni_offset +  8 * BYTES_PER_ELEMENT, acc[mi][ni].reg(2));
+                fmha::sts(this->smem_write_ + ni_offset + 12 * BYTES_PER_ELEMENT, acc[mi][ni].reg(3));
+            }
+        }
+#else
         enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA };
         #pragma unroll
         for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {

@@ -1109,6 +1290,7 @@ struct Smem_tile_o {
             // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B.
             this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
         }
+#endif
     }
 };
@@ -1177,11 +1359,11 @@ struct Smem_tile_mma_transposed : public Base {
     enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };
     enum { WARPS_M = Base::WARPS_M };
     enum { WARPS_N = Base::WARPS_N };
-    static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
+    static_assert(WARPS_M == 1 && (WARPS_N == 2 || WARPS_N == 4 || WARPS_N == 8));
     using Fragment = typename Base::Fragment;
     inline __device__ Smem_tile_mma_transposed(char *smem, int tidx) : Base(smem, tidx) {
-        static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
+        static_assert(WARPS_M == 1 && (WARPS_N == 2 || WARPS_N == 4 || WARPS_N == 8));
         int read_row, read_col;
         read_row = (tidx & 0x0f);
         read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;

@@ -1221,7 +1403,7 @@ struct Smem_tile_mma_epilogue : public Base {
     static_assert(NUM_LDS * ROWS_PER_LDS == Cta_tile::M);
     enum { WARPS_M = Base::WARPS_M };
     enum { WARPS_N = Base::WARPS_N };
-    static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);
+    static_assert((WARPS_N == 2 || WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);
     using Acc = fmha::Fragment_accumulator;
File: apex/contrib/csrc/fmha/src/fmha/softmax.h (+170 -1)
@@ -56,8 +56,13 @@ inline __device__ float apply_exp_(float x, float max) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template<int COLS> struct ReadType {};
+#if defined(__HIP_PLATFORM_HCC__)
+template<> struct ReadType<2> { using T = float; };
+template<> struct ReadType<4> { using T = float2; };
+#else
 template<> struct ReadType<4> { using T = float; };
 template<> struct ReadType<8> { using T = float2; };
+#endif

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -78,7 +83,11 @@ struct Smem_tile_reduce {
     static constexpr int ROWS = WARPS_M * MMAS_M * 16;
     static constexpr int COLS = WARPS_N;
+#if defined(__HIP_PLATFORM_HCC__)
+    static_assert(COLS == 2 || COLS == 4);
+#else
     static_assert(COLS == 4 || COLS == 8);
+#endif
     static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
     static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
     static constexpr int ELTS_PER_TILE = ROWS * COLS;
...
@@ -93,6 +102,20 @@ struct Smem_tile_reduce {
__device__
inline
Smem_tile_reduce
(
float
*
smem_
,
const
int
tidx
)
{
__device__
inline
Smem_tile_reduce
(
float
*
smem_
,
const
int
tidx
)
{
#if defined(__HIP_PLATFORM_HCC__)
int
lane
=
tidx
%
Cta_tile
::
THREADS_PER_WARP
;
int
warp
=
tidx
/
Cta_tile
::
THREADS_PER_WARP
;
int
warp_m
=
warp
%
WARPS_M
;
int
warp_n
=
warp
/
WARPS_M
;
qid_
=
lane
/
16
;
// 前16个线程才能写入
int
qp
=
lane
%
16
;
const
int
col
=
warp_n
;
smem_write_
=
&
smem_
[
warp_m
*
ELTS_PER_TILE
+
qp
*
WARPS_N
+
col
];
smem_read_
=
&
reinterpret_cast
<
read_t
*>
(
smem_
)[
warp_m
*
ELTS_PER_TILE
+
qp
*
2
+
qid_
/
WARPS_N
];
#else
int
lane
=
tidx
%
32
;
int
lane
=
tidx
%
32
;
int
warp
=
tidx
/
32
;
int
warp
=
tidx
/
32
;
...
@@ -107,7 +130,17 @@ struct Smem_tile_reduce {
         const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN);
         smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col];
         smem_read_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_];
+#endif
     }

+    __device__ inline void store(float (&frag)[MMAS_M]) {
+        if( qid_ == 0 ) {
+            #pragma unroll
+            for( int mi = 0; mi < MMAS_M; mi++ ) {
+                *smem_write_ = frag[mi];
+            }
+        }
+    }
+
     __device__ inline void store(float (&frag)[2 * MMAS_M]) {
@@ -121,6 +154,13 @@ struct Smem_tile_reduce {
         }
     }

+    __device__ inline void load(read_t (&frag)[MMAS_M]) {
+        #pragma unroll
+        for( int mi = 0; mi < MMAS_M; mi++ ) {
+            frag[mi] = *smem_read_;
+        }
+    }
+
     __device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
         #pragma unroll
         for( int mi = 0; mi < MMAS_M; mi++ ) {
...
@@ -188,6 +228,29 @@ struct Softmax_base {
smem_read_
=
&
smem_
[
warp_m
*
Mma_tile
::
M_PER_MMA
+
lane
/
4
];
smem_read_
=
&
smem_
[
warp_m
*
Mma_tile
::
M_PER_MMA
+
lane
/
4
];
}
}
#if defined(__HIP_PLATFORM_HCC__)
template
<
typename
Mask
>
inline
__device__
void
apply_mask
(
const
Mask
&
mask
)
{
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
;
++
mi
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
MMAS_N
;
++
ni
)
{
#pragma unroll
for
(
int
jj
=
0
;
jj
<
4
;
++
jj
)
{
if
(
!
mask
.
is_valid
(
mi
,
ni
,
jj
)
)
{
elt_
[
mi
][
4
*
ni
+
jj
]
=
-
INFINITY
;
}
}
// wangaq debug
// if (blockIdx.x == 0) {
// printf("apply_mask tid:%d mi:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, mi, ni,
// this->elt_[mi][4 * ni + 0], this->elt_[mi][4 * ni + 1], this->elt_[mi][4 * ni + 2], this->elt_[mi][4 * ni + 3]);
// }
}
}
}
#else
template
<
typename
Mask
>
template
<
typename
Mask
>
inline
__device__
void
apply_mask
(
const
Mask
&
mask
)
{
inline
__device__
void
apply_mask
(
const
Mask
&
mask
)
{
#pragma unroll
#pragma unroll
...
@@ -206,6 +269,26 @@ struct Softmax_base {
             }
         }
     }
+#endif

+    // Apply the exp to all the elements.
+    inline __device__ void apply_exp(const float (&max)[MMAS_M]) {
+        #pragma unroll
+        for( int mi = 0; mi < MMAS_M; ++mi ) {
+            #pragma unroll
+            for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
+                elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]);
+            }
+            // wangaq debug
+            // if (blockIdx.x == 0) {
+            //     for( int ni = 0; ni < MMAS_N; ++ni ) {
+            //         printf("apply_exp tid:%d mi:%d ni:%d max:%6.4f %f %f %f %f\n", threadIdx.x, mi, ni, max[mi],
+            //                this->elt_[mi][4 * ni + 0], this->elt_[mi][4 * ni + 1], this->elt_[mi][4 * ni + 2], this->elt_[mi][4 * ni + 3]);
+            //     }
+            // }
+        }
+    }
+
     // Apply the exp to all the elements.
     inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) {
...
@@ -218,6 +301,33 @@ struct Softmax_base {
}
}
}
}
// Scale all the elements.
inline
__device__
void
scale
(
const
float
(
&
sum
)[
MMAS_M
])
{
// Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal.
float
inv_sum
[
MMAS_M
];
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
;
++
mi
)
{
inv_sum
[
mi
]
=
(
sum
[
mi
]
==
0.
f
||
sum
[
mi
]
!=
sum
[
mi
])
?
1.
f
:
1.
f
/
sum
[
mi
];
}
// Update the values.
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
;
++
mi
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
MMAS_N
*
4
;
++
ni
)
{
elt_
[
mi
][
ni
]
*=
inv_sum
[
mi
];
}
// wangaq debug
// if (blockIdx.x == 0) {
// for( int ni = 0; ni < MMAS_N; ++ni ) {
// printf("scale tid:%d mi:%d ni:%d sum:%6.4f inv_sum:%6.4f %f %f %f %f\n", threadIdx.x, mi, ni, sum[mi], inv_sum[mi],
// this->elt_[mi][4 * ni + 0], this->elt_[mi][4 * ni + 1], this->elt_[mi][4 * ni + 2], this->elt_[mi][4 * ni + 3]);
// }
// }
}
}
// Scale all the elements.
// Scale all the elements.
inline
__device__
void
scale
(
const
float
(
&
sum
)[
MMAS_M
*
2
])
{
inline
__device__
void
scale
(
const
float
(
&
sum
)[
MMAS_M
*
2
])
{
// Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal.
// Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal.
...
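The guard in the new scale() is worth isolating: a zero or NaN row sum (sum != sum detects NaN) falls back to 1.f, so fully masked rows normalize to themselves instead of producing Inf or NaN:

    __device__ float safe_inv(float sum) {
        return (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
    }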
@@ -244,7 +354,11 @@ struct Softmax_base {
     // The current thread index.
     int tidx_;
     // The elements.
+#if defined(__HIP_PLATFORM_HCC__)
+    float elt_[MMAS_M][MMAS_N * 4];
+#else
     float elt_[MMAS_M * 2][MMAS_N * 4];
+#endif
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -290,7 +404,20 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
         for( int mi = 0; mi < M; ++mi ) {
             #pragma unroll
             for( int ki = 0; ki < K; ++ki ) {
+#if defined(__HIP_PLATFORM_HCC__)
+                dst[ki][mi].template elt_as<float>(0) = this->elt_[mi][4 * ki + 0];
+                dst[ki][mi].template elt_as<float>(1) = this->elt_[mi][4 * ki + 1];
+                dst[ki][mi].template elt_as<float>(2) = this->elt_[mi][4 * ki + 2];
+                dst[ki][mi].template elt_as<float>(3) = this->elt_[mi][4 * ki + 3];
+                // wangaq debug
+                // if (blockIdx.x == 0) {
+                //     printf("pack tid:%d mi:%d ki:%d %6.4f %6.4f %6.4f %6.4f -> %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, mi, ki,
+                //            this->elt_[mi][4 * ki + 0], this->elt_[mi][4 * ki + 1], this->elt_[mi][4 * ki + 2], this->elt_[mi][4 * ki + 3],
+                //            dst[ki][mi].template elt_as<float>(0), dst[ki][mi].template elt_as<float>(1), dst[ki][mi].template elt_as<float>(2), dst[ki][mi].template elt_as<float>(3));
+                //     // printf("pack tid:%d mi:%d ki:%d %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, mi, ki,
+                //     //        dst[ki][mi].template elt_as<float>(0), dst[ki][mi].template elt_as<float>(1), dst[ki][mi].template elt_as<float>(2), dst[ki][mi].template elt_as<float>(3));
+                // }
+#else
                 // 1st row - 4 elements per row.
                 float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0];
                 float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1];
...
@@ -308,6 +435,7 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
dst
[
ki
][
mi
].
reg
(
1
)
=
fmha
::
float2_to_half2
(
tmp_10
,
tmp_11
);
dst
[
ki
][
mi
].
reg
(
1
)
=
fmha
::
float2_to_half2
(
tmp_10
,
tmp_11
);
dst
[
ki
][
mi
].
reg
(
2
)
=
fmha
::
float2_to_half2
(
tmp_02
,
tmp_03
);
dst
[
ki
][
mi
].
reg
(
2
)
=
fmha
::
float2_to_half2
(
tmp_02
,
tmp_03
);
dst
[
ki
][
mi
].
reg
(
3
)
=
fmha
::
float2_to_half2
(
tmp_12
,
tmp_13
);
dst
[
ki
][
mi
].
reg
(
3
)
=
fmha
::
float2_to_half2
(
tmp_12
,
tmp_13
);
#endif
}
}
}
}
}
}
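Editor's note: on the non-HIP path `pack()` converts the fp32 softmax elements into packed fp16 pairs, the register shape HMMA fragments expect; on HIP the mmac rewrite keeps the fragments in fp32, so the values are copied through `elt_as<float>` unchanged. A sketch of what the fp16 packing amounts to, assuming `fmha::float2_to_half2` is a plain round-to-nearest conversion (the library's exact implementation may differ):

    #include <cuda_fp16.h>
    #include <cstdint>

    // Pack two fp32 values into one 32-bit register holding a __half2.
    __device__ inline uint32_t pack_float2_to_half2(float a, float b) {
        __half2 h = __floats2half2_rn(a, b);   // round-to-nearest f32 -> f16x2
        uint32_t r;
        __builtin_memcpy(&r, &h, sizeof(r));   // reinterpret as a raw register
        return r;
    }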
...
@@ -340,6 +468,19 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
        for( int mi = 0; mi < MMAS_M; ++mi ) {
            #pragma unroll
            for( int ni = 0; ni < MMAS_N; ++ni ) {
+#if defined(__HIP_PLATFORM_HCC__)
+                this->elt_[mi][4 * ni + 0] = acc[mi][ni].elt(0);
+                this->elt_[mi][4 * ni + 1] = acc[mi][ni].elt(1);
+                this->elt_[mi][4 * ni + 2] = acc[mi][ni].elt(2);
+                this->elt_[mi][4 * ni + 3] = acc[mi][ni].elt(3);
+                // wangaq debug
+                // if (blockIdx.x == 0) {
+                //     printf("unpack_noscale tid:%d mi:%d ni:%d %6.4f %6.4f %6.4f %6.4f -> %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, mi, ni,
+                //            acc[mi][ni].template elt_as<float>(0), acc[mi][ni].template elt_as<float>(1), acc[mi][ni].template elt_as<float>(2), acc[mi][ni].template elt_as<float>(3),
+                //            this->elt_[mi][4 * ni + 0], this->elt_[mi][4 * ni + 1], this->elt_[mi][4 * ni + 2], this->elt_[mi][4 * ni + 3]);
+                // }
+#else
                // 1st row - 4 elements per row.
                this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0);
                this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1);
...
@@ -350,11 +491,39 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
                this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3);
                this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6);
                this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7);
+#endif
            }
        }
    }
+    template<typename Operator>
+    __device__ inline void reduce_(float (&frag)[MMAS_M], Operator &op, Smem_tile_red &smem_red) {
+        for (int mi = 0; mi < MMAS_M; mi++) {
+            frag[mi] = this->elt_[mi][0];
+            for (int ni = 1; ni < 4 * MMAS_N; ni++) {
+                frag[mi] = op(frag[mi], this->elt_[mi][ni]);
+            }
+        }
+        quad_reduce(frag, frag, op);
+        smem_red.store(frag);
+        __syncthreads();
+        typename Smem_tile_red::read_t tmp[MMAS_M];
+        smem_red.load(tmp);
+        binary_allreduce(frag, tmp, op);
+    }
+
+    __device__ inline void reduce_max(float (&frag)[MMAS_M]){ MaxOp<float> max; reduce_(frag, max, smem_max_); }
+    __device__ inline void reduce_sum(float (&frag)[MMAS_M]){ SumOp<float> sum; reduce_(frag, sum, smem_sum_); }

    template<typename Operator>
    __device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red &smem_red) {
...
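Editor's note: the new `reduce_` overload is the wave64 reduction pipeline — each thread folds its `4 * MMAS_N` row elements with the operator, `quad_reduce` combines the lanes that own the same row, partial results pass through a small shared-memory tile (`smem_red`), and `binary_allreduce` makes the final value uniform. A self-contained sketch of the two register-level phases (names hypothetical; the shared-memory exchange is omitted):

    // Fold a thread's private elements, then allreduce across the warp with
    // XOR shuffles. `Op` is a functor like the MaxOp/SumOp used above.
    template<typename Op, int N>
    __device__ float fold_then_allreduce(const float (&elt)[N], Op op, int width) {
        float acc = elt[0];
        #pragma unroll
        for (int i = 1; i < N; ++i) { acc = op(acc, elt[i]); }
        for (int offset = width / 2; offset > 0; offset /= 2) {
            acc = op(acc, __shfl_xor_sync(uint32_t(-1), acc, offset));
        }
        return acc;
    }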
apex/contrib/csrc/fmha/src/fmha/utils.h
View file @ 7eed2594
...
@@ -294,6 +294,13 @@ static inline __device__ uint32_t hmul2(uint32_t a, uint32_t b) {
    return c;
}

+static inline __device__ uint32_t fmul(uint32_t a, uint32_t b) {
+    uint32_t c;
+    float tmp = reinterpret_cast<float&>(a) * reinterpret_cast<float&>(b);
+    c = reinterpret_cast<uint32_t&>(tmp);
+    return c;
+}
////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint2 hmul4(uint2 a, uint2 b) {
...
@@ -346,6 +353,15 @@ static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) {
#endif
    return res;
}

+static inline __device__ uint32_t frelu(uint32_t x, uint32_t lb = 0) {
+    uint32_t res;
+    float tmp_x = reinterpret_cast<float&>(x);
+    tmp_x = tmp_x > lb ? tmp_x : lb;
+    __builtin_memcpy(&res, &tmp_x, sizeof(uint32_t));
+    return res;
+}
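Editor's note: `fmul` and `frelu` mirror `hmul2`/`hrelu2` but treat the 32-bit register as a single fp32 value instead of two packed fp16 values, matching the fp32 fragments of the mmac path (note `tmp_x > lb` relies on the implicit conversion of the integer lower bound, normally 0). A host-side illustration of the bit-punning idiom, using memcpy as the strictly well-defined way to reinterpret bits:

    #include <cstdint>
    #include <cstring>
    #include <cstdio>

    // Reinterpret a raw 32-bit register as one fp32, multiply, store back.
    static uint32_t fmul_host(uint32_t a, uint32_t b) {
        float fa, fb, tmp;
        std::memcpy(&fa, &a, sizeof fa);
        std::memcpy(&fb, &b, sizeof fb);
        tmp = fa * fb;
        uint32_t c;
        std::memcpy(&c, &tmp, sizeof c);
        return c;
    }

    int main() {
        float x = 1.5f, y = 2.0f, z;
        uint32_t bx, by;
        std::memcpy(&bx, &x, 4);
        std::memcpy(&by, &y, 4);
        uint32_t bz = fmul_host(bx, by);
        std::memcpy(&z, &bz, 4);
        printf("%f\n", z);  // 3.000000
        return 0;
    }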
static inline __device__ uint32_t habs2(uint32_t x) {
    uint32_t res;
#if defined (__HIP_PLATFORM_HCC__)
...
@@ -905,10 +921,19 @@ inline __device__ void lds(uint4 &dst, uint32_t ptr) {
////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldsm(uint32_t &dst, uint32_t ptr) {
+#if defined (__HIP_PLATFORM_HCC__)
+    extern __shared__ char smem[];
+    uint32_t base = __nvvm_get_smem_pointer(smem);
+    float tmp = __half2float(*(__half *)(smem - base + ptr));
+    __builtin_memcpy(&dst, &tmp, sizeof(uint32_t));
+    // if (blockIdx.x == 0)
+    //     printf("ldsm tid:%d tmp:%f\n", threadIdx.x, tmp);
+#else
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
    asm volatile ("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n"
        : "=r"(dst) : "r"(ptr));
#endif
+#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////
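Editor's note: gfx targets have no `ldmatrix`, so the HIP branch of `ldsm` emulates it with a plain LDS read — the 32-bit `ptr` offset is turned back into a generic address via the `__nvvm_get_smem_pointer` base, and one fp16 element is loaded and widened to fp32 for the mmac fragment. A simplified sketch of the addressing trick (hypothetical helper; the committed code is equivalent):

    // Convert a 32-bit shared-memory offset back to a generic pointer and
    // widen the fp16 element it designates to fp32.
    __device__ inline float lds_half_as_float(uint32_t ptr) {
        extern __shared__ char smem[];
        uint32_t base = __nvvm_get_smem_pointer(smem);  // offset of smem[0]
        return __half2float(*(const __half *)(smem - base + ptr));
    }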
...
@@ -1114,10 +1139,32 @@ inline __device__ void sts(uint32_t ptr, uint2 val) {
inline __device__ void sts(uint32_t ptr, uint4 val) {
#if defined (__HIP_PLATFORM_HCC__)
-    asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr) , "v"(val.x));
-    asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+4) , "v"(val.y));
-    asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+8) , "v"(val.z));
-    asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+12) , "v"(val.w));
+    // asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr) , "v"(val.x));
+    // asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+4) , "v"(val.y));
+    // asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+8) , "v"(val.z));
+    // asm volatile("ds_write_b32 %0, %1;" : : "v"(ptr+12) , "v"(val.w));
+    extern __shared__ char smem[];
+    uint32_t base = __nvvm_get_smem_pointer(smem);
+    __builtin_memcpy(smem - base + ptr, &val, sizeof(uint4));
+    // if (blockIdx.x == 0) {
+    //     printf("sts tid:%d %f %f %f %f %f %f %f %f -> %f %f %f %f %f %f %f %f\n", threadIdx.x,
+    //            __half2float(reinterpret_cast<half*>(&val)[0]),
+    //            __half2float(reinterpret_cast<half*>(&val)[1]),
+    //            __half2float(reinterpret_cast<half*>(&val)[2]),
+    //            __half2float(reinterpret_cast<half*>(&val)[3]),
+    //            __half2float(reinterpret_cast<half*>(&val)[4]),
+    //            __half2float(reinterpret_cast<half*>(&val)[5]),
+    //            __half2float(reinterpret_cast<half*>(&val)[6]),
+    //            __half2float(reinterpret_cast<half*>(&val)[7]),
+    //            __half2float(reinterpret_cast<half*>(smem-base+ptr)[0]),
+    //            __half2float(reinterpret_cast<half*>(smem-base+ptr)[1]),
+    //            __half2float(reinterpret_cast<half*>(smem-base+ptr)[2]),
+    //            __half2float(reinterpret_cast<half*>(smem-base+ptr)[3]),
+    //            __half2float(reinterpret_cast<half*>(smem-base+ptr)[4]),
+    //            __half2float(reinterpret_cast<half*>(smem-base+ptr)[5]),
+    //            __half2float(reinterpret_cast<half*>(smem-base+ptr)[6]),
+    //            __half2float(reinterpret_cast<half*>(smem-base+ptr)[7]));
+    // }
#else
    asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n"
        : :
...
@@ -1190,7 +1237,7 @@ struct Allreduce {
    static __device__ inline T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
#if defined (__HIP_PLATFORM_HCC__)
-        x = op(x, __shfl_xor(uint32_t(-1), x, OFFSET));
+        x = op(x, __shfl_xor(x, OFFSET));
#else
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
#endif
...
@@ -1221,8 +1268,8 @@ __device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator &
    for (int mi = 0; mi < M; mi++){
        dst[mi] = src[mi];
#if defined (__HIP_PLATFORM_HCC__)
-        dst[mi] = op(dst[mi], __shfl_down(dst[mi], 2));
-        dst[mi] = op(dst[mi], __shfl_down(dst[mi], 1));
+        dst[mi] = op(dst[mi], __shfl_down(dst[mi], 32));
+        dst[mi] = op(dst[mi], __shfl_down(dst[mi], 16));
#else
        dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
        dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
...
@@ -1267,4 +1314,34 @@ __device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operato
////////////////////////////////////////////////////////////////////////////////////////////////////

+#if defined (__HIP_PLATFORM_HCC__)
+template<int THREADS>
+struct Allreduce32 {
+    static_assert(THREADS == 64);
+    template<typename T, typename Operator>
+    static __device__ inline T run(T x, Operator &op) {
+        return op(x, __shfl_xor(x, 32));
+    }
+};
+
+template<typename Operator, int M>
+__device__ inline void binary_allreduce(float (&dst)[M], float (&src)[M], Operator &op) {
+    #pragma unroll
+    for (int mi = 0; mi < M; mi++){
+        dst[mi] = src[mi];
+        dst[mi] = Allreduce32<64>::run(dst[mi], op);
+    }
+}
+
+template<typename Operator, int M>
+__device__ inline void binary_allreduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
+    float tmp[M];
+    #pragma unroll
+    for (int mi = 0; mi < M; mi++){
+        tmp[mi] = op(src[mi].x, src[mi].y);
+    }
+    binary_allreduce(dst, tmp, op);
+}
+#endif

}  // namespace fmha
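Editor's note: `Allreduce32` performs the final cross-half merge of a 64-lane AMD wavefront — after values are reduced within 32-lane groups, one more XOR shuffle with lane mask 32 makes the result uniform, which `binary_allreduce` applies per row. A sketch of a full wave64 allreduce built from the same primitive (hypothetical free function; HIP's `__shfl_xor` takes no participation mask):

    // Butterfly allreduce over a 64-lane AMD wavefront.
    template<typename Op>
    __device__ inline float wave64_allreduce(float x, Op op) {
        #pragma unroll
        for (int mask = 32; mask > 0; mask >>= 1) {
            x = op(x, __shfl_xor(x, mask));
        }
        return x;
    }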
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu
View file @ 7eed2594
...
@@ -28,7 +28,11 @@
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"

+#if defined(__HIP_PLATFORM_HCC__)
+using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 2, 0x08u>;
+#else
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+#endif

extern "C" __global__ void fmha_dgrad_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::compute_dv_1xN<Kernel_traits>(params);
...
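Editor's note: across these kernel files the HIP trait lists halve the fifth template argument relative to the CUDA ones (4 -> 2, and 8 -> 4 for the 512 case). Assuming that argument counts warps per CTA, this keeps the CTA's thread count unchanged when a "warp" is a 64-lane wavefront rather than 32 lanes; a hypothetical sanity check of that reading:

    // Threads per CTA stay constant when lanes-per-warp doubles.
    constexpr int cuda_threads = 4 /* warps */ * 32 /* lanes */;
    constexpr int hip_threads  = 2 /* warps */ * 64 /* lanes */;
    static_assert(cuda_threads == hip_threads, "CTA shape preserved on wave64");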
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu
View file @ 7eed2594
...
@@ -28,7 +28,11 @@
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"

+#if defined(__HIP_PLATFORM_HCC__)
+using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 2, 0x08u>;
+#else
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+#endif

template<bool Is_training>
__global__
...
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu
View file @ 7eed2594
...
@@ -28,7 +28,11 @@
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"

+#if defined(__HIP_PLATFORM_HCC__)
+using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 2, 0x08u>;
+#else
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
+#endif

template<bool Is_training>
__global__
...
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu
View file @ 7eed2594
...
@@ -28,7 +28,11 @@
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"

+#if defined(__HIP_PLATFORM_HCC__)
+using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 2, 0x18u>;
+#else
using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 4, 0x18u>;
+#endif

template<bool Is_training>
__global__
...
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu
View file @ 7eed2594
...
@@ -28,7 +28,11 @@
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"

+#if defined(__HIP_PLATFORM_HCC__)
+using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 4, 0x00u>;
+#else
using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x00u>;
+#endif

template<bool Is_training>
__global__
...
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h
View file @ 7eed2594
...
@@ -111,11 +111,67 @@ struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
            Base::smem_q.load(Base::frag_q[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
+            // wangaq debug
+            // __syncthreads();
+            // if (blockIdx.x == 0) {
+            //     for(int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
+            //         printf("frag_q[%d] tid:%d ki:%d mi:%d %6.4f %6.4f %6.4f %6.4f\n", (ki - 1) & 1, threadIdx.x, ki, mi,
+            //                Base::frag_q[(ki - 1) & 1][mi].template elt_as<float>(0),
+            //                Base::frag_q[(ki - 1) & 1][mi].template elt_as<float>(1),
+            //                Base::frag_q[(ki - 1) & 1][mi].template elt_as<float>(2),
+            //                Base::frag_q[(ki - 1) & 1][mi].template elt_as<float>(3));
+            //     }
+            //     for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
+            //         printf("frag_k[%d] tid:%d ki:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", (ki - 1) & 1, threadIdx.x, ki, ni,
+            //                frag_k[(ki - 1) & 1][ni].template elt_as<float>(0),
+            //                frag_k[(ki - 1) & 1][ni].template elt_as<float>(1),
+            //                frag_k[(ki - 1) & 1][ni].template elt_as<float>(2),
+            //                frag_k[(ki - 1) & 1][ni].template elt_as<float>(3));
+            //     }
+            //     for(int m = 0; m < M; ++m) {
+            //         for (int n = 0; n < N; ++n) {
+            //             printf("acc_p tid:%d ki:%d mi:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, ki, m, n,
+            //                    acc_p[m][n].template elt_as<float>(0),
+            //                    acc_p[m][n].template elt_as<float>(1),
+            //                    acc_p[m][n].template elt_as<float>(2),
+            //                    acc_p[m][n].template elt_as<float>(3));
+            //         }
+            //     }
+            // }
        }
        // Do the final stage of math.
        {
            int ki = Mma_tile_p::MMAS_K;
            fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
+            // wangaq debug
+            // __syncthreads();
+            // if (blockIdx.x == 0) {
+            //     for(int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
+            //         printf("frag_q[%d] tid:%d ki:%d mi:%d %6.4f %6.4f %6.4f %6.4f\n", (ki - 1) & 1, threadIdx.x, ki, mi,
+            //                Base::frag_q[(ki - 1) & 1][mi].template elt_as<float>(0),
+            //                Base::frag_q[(ki - 1) & 1][mi].template elt_as<float>(1),
+            //                Base::frag_q[(ki - 1) & 1][mi].template elt_as<float>(2),
+            //                Base::frag_q[(ki - 1) & 1][mi].template elt_as<float>(3));
+            //     }
+            //     for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
+            //         printf("frag_k[%d] tid:%d ki:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", (ki - 1) & 1, threadIdx.x, ki, ni,
+            //                frag_k[(ki - 1) & 1][ni].template elt_as<float>(0),
+            //                frag_k[(ki - 1) & 1][ni].template elt_as<float>(1),
+            //                frag_k[(ki - 1) & 1][ni].template elt_as<float>(2),
+            //                frag_k[(ki - 1) & 1][ni].template elt_as<float>(3));
+            //     }
+            //     for(int m = 0; m < M; ++m) {
+            //         for (int n = 0; n < N; ++n) {
+            //             printf("acc_p tid:%d ki:%d mi:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", threadIdx.x, ki, m, n,
+            //                    acc_p[m][n].template elt_as<float>(0),
+            //                    acc_p[m][n].template elt_as<float>(1),
+            //                    acc_p[m][n].template elt_as<float>(2),
+            //                    acc_p[m][n].template elt_as<float>(3));
+            //         }
+            //     }
+            // }
        }
    }
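Editor's note: `Gemm_Q_K` is software-pipelined — while the MMA for K-step `ki - 1` executes, the fragment for step `ki` is already being loaded from shared memory into the other half of a two-deep register buffer indexed by `ki & 1`, and a final un-prefetched step drains the pipeline. The same schedule reduced to scalars (names hypothetical):

    // Two-deep double buffering: fetch q[ki] while multiplying q[ki-1]*k[ki-1].
    __device__ inline float pipelined_dot(const float *q, const float *k, int mmas_k) {
        float acc = 0.f;
        float buf[2];
        buf[0] = q[0];                                 // prime the pipeline
        for (int ki = 1; ki < mmas_k; ++ki) {
            buf[ki & 1] = q[ki];                       // load next fragment
            acc += buf[(ki - 1) & 1] * k[ki - 1];      // math on previous one
        }
        acc += buf[(mmas_k - 1) & 1] * k[mmas_k - 1];  // final stage of math
        return acc;
    }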
...
@@ -188,6 +244,7 @@ constexpr size_t get_dynamic_smem_size(){
template<typename Kernel_traits, bool Is_training, typename Params, typename Prng>
inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, const int begin, const int steps, Prng &ph) {
+    // if (blockIdx.x == 0 && threadIdx.x == 0) printf("steps:%d\n", steps);
    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
...
@@ -291,9 +348,43 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
    __syncthreads();
+    // wangaq debug
+    // if (blockIdx.x == 0 && tidx == 0) {
+    //     __half * smem = reinterpret_cast<__half*>(smem_);
+    //     printf("begin:%d q %d bytes smem:\n", begin, Gemm1::Smem_tile_q::BYTES_PER_TILE);
+    //     for (int row = 0; row < Gemm1::Smem_tile_q::BYTES_PER_TILE / 2 / 8; ++row) {
+    //         printf("row:%d ", row);
+    //         for (int col = 0; col < 8; ++col) {
+    //             printf("col:%d value:%6.4f\t", col, __half2float(smem[row*8+col]));
+    //         }
+    //         printf("\n");
+    //     }
+    //     printf("begin:%d v %d bytes smem:\n", begin, Smem_tile_v::BYTES_PER_TILE);
+    //     smem = reinterpret_cast<__half*>(smem_v_);
+    //     for (int row = 0; row < Smem_tile_v::BYTES_PER_TILE / 2 / 8; ++row) {
+    //         printf("row:%d ", row);
+    //         for (int col = 0; col < 8; ++col) {
+    //             printf("col:%d value:%6.4f\t", col, __half2float(smem[row*8+col]));
+    //         }
+    //         printf("\n");
+    //     }
+    // }
    // Load the fragments for Q.
    gemm_q_k.load_q();
+    // wangaq debug
+    // __syncthreads();
+    // if (blockIdx.x == 0) {
+    //     for(int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
+    //         printf("frag_q tid:%d mi:%d %6.4f %6.4f %6.4f %6.4f\n", tidx, mi,
+    //                gemm_q_k.frag_q[0][mi].template elt_as<float>(0),
+    //                gemm_q_k.frag_q[0][mi].template elt_as<float>(1),
+    //                gemm_q_k.frag_q[0][mi].template elt_as<float>(2),
+    //                gemm_q_k.frag_q[0][mi].template elt_as<float>(3));
+    //     }
+    // }
    // Load the fragments for V. We keep the data in registers during the entire kernel.
    typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
    #pragma unroll
...
@@ -311,11 +402,39 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
        // Make sure the data is in shared memory.
        __syncthreads();
+        // wangaq debug
+        // if (blockIdx.x == 0 && tidx == 0) {
+        //     printf("begin:%d k %d bytes smem:\n", begin, Gemm1::Smem_tile_k::BYTES_PER_TILE);
+        //     __half * smem = reinterpret_cast<__half*>(smem_ + Gemm1::Smem_tile_q::BYTES_PER_TILE);
+        //     for (int row = 0; row < Gemm1::Smem_tile_k::BYTES_PER_TILE / 2 / 8; ++row) {
+        //         printf("row:%d ", row);
+        //         for (int col = 0; col < 8; ++col) {
+        //             printf("col:%d value:%6.4f\t", col, __half2float(smem[row*8+col]));
+        //         }
+        //         printf("\n");
+        //     }
+        // }
    }

    // Load the fragments for K.
    gemm_q_k.load_k();
+    // wangaq debug
+    // __syncthreads();
+    // if (blockIdx.x == 0) {
+    //     // Fragment_k frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
+    //     for(int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki) {
+    //         for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
+    //             printf("frag_k tid:%d ki:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", tidx, ki, ni,
+    //                    gemm_q_k.frag_k[ki][ni].template elt_as<float>(0),
+    //                    gemm_q_k.frag_k[ki][ni].template elt_as<float>(1),
+    //                    gemm_q_k.frag_k[ki][ni].template elt_as<float>(2),
+    //                    gemm_q_k.frag_k[ki][ni].template elt_as<float>(3));
+    //         }
+    //     }
+    // }
    // Create the object to do the softmax.
    Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);
...
@@ -330,6 +449,19 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
        // Do this part of P^T = (Q * K^T)^T.
        gemm_q_k(acc_p);
+        // wangaq debug
+        // if (blockIdx.x == 0) {
+        //     for(int m = 0; m < Mma_tile_p::MMAS_M; ++m) {
+        //         for (int n = 0; n < Mma_tile_p::MMAS_N; ++n) {
+        //             printf("acc_p steps:%d step:%d tid:%d mi:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", steps, l, threadIdx.x, m, n,
+        //                    acc_p[m][n].template elt_as<float>(0),
+        //                    acc_p[m][n].template elt_as<float>(1),
+        //                    acc_p[m][n].template elt_as<float>(2),
+        //                    acc_p[m][n].template elt_as<float>(3));
+        //         }
+        //     }
+        // }
        // Trigger the load for the next Q values.
        if( l < steps - 1 ) {
            gemm_q_k.smem_q.move_to_next_write_buffer();
...
@@ -351,18 +483,33 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
            __syncthreads();
        }

        // Compute the max.
+#if defined(__HIP_PLATFORM_HCC__)
+        float p_max[Mma_tile_p::MMAS_M];
+#else
        float p_max[Mma_tile_p::MMAS_M * 2];
+#endif
        //softmax.template reduce<fmha::Max_>(p_max);
        softmax.reduce_max(p_max);
+        // wangaq debug
+        // if (blockIdx.x == 0) {
+        //     for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
+        //         printf("tid:%d mi:%d p_max:%f\n", threadIdx.x, mi, p_max[mi]);
+        //     }
+        // }
        // Compute the exponential value.
        softmax.apply_exp(p_max);

        // Compute the sum.
+#if defined(__HIP_PLATFORM_HCC__)
+        float p_sum[Mma_tile_p::MMAS_M];
+#else
        float p_sum[Mma_tile_p::MMAS_M * 2];
+#endif
        softmax.reduce_sum(p_sum);

-        // Finalize softmax on the accumulators of P^T.
+        // Finalize softmax on the accumulators of P.
        softmax.scale(p_sum);

        using Frag_p = fmha::Fragment_a<fmha::Row>;
...
@@ -371,12 +518,20 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
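Editor's note: this block is the standard numerically stable softmax — reduce the row max, subtract and exponentiate, reduce the row sum, then normalize with the guarded inverse. A minimal single-thread sketch of the recipe (the kernel performs each phase with warp/CTA-wide reductions instead of this loop; names hypothetical):

    // Stable softmax over one row of n scores, in place.
    __device__ inline void softmax_row(float *p, int n) {
        float m = p[0];
        for (int i = 1; i < n; ++i) { m = fmaxf(m, p[i]); }
        float sum = 0.f;
        for (int i = 0; i < n; ++i) { p[i] = __expf(p[i] - m); sum += p[i]; }
        float inv = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
        for (int i = 0; i < n; ++i) { p[i] *= inv; }
    }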
            auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < 2; ii++ ) {
                    #pragma unroll
                    for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
                        float4 tmp = uniform4(ph());
                        // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from pre-existing zeros
+#if defined(__HIP_PLATFORM_HCC__)
+                        softmax.elt_[mi][4 * ni + 0] = encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[mi][4 * ni + 0]);
+                        softmax.elt_[mi][4 * ni + 1] = encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[mi][4 * ni + 1]);
+                        softmax.elt_[mi][4 * ni + 2] = encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[mi][4 * ni + 2]);
+                        softmax.elt_[mi][4 * ni + 3] = encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[mi][4 * ni + 3]);
+#else
                        softmax.elt_[2 * mi + ii][4 * ni + 0] =
                            encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
                        softmax.elt_[2 * mi + ii][4 * ni + 1] =
...
@@ -385,11 +540,34 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
                            encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
                        softmax.elt_[2 * mi + ii][4 * ni + 3] =
                            encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
+#endif
                    }
                }
            }
                softmax.pack(frag_p);
+                // wangaq debug
+                // __syncthreads();
+                // if (blockIdx.x == 0) {
+                //     for (int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki) {
+                //         for (int mi = 0; mi < Mma_tile_o::MMAS_M; ++mi) {
+                //             printf("frag_p tid:%d ki:%d mi:%d %6.4f %6.4f %6.4f %6.4f\n", tidx, ki, mi,
+                //                    frag_p[ki][mi].template elt_as<float>(0),
+                //                    frag_p[ki][mi].template elt_as<float>(1),
+                //                    frag_p[ki][mi].template elt_as<float>(2),
+                //                    frag_p[ki][mi].template elt_as<float>(3));
+                //         }
+                //     }
+                // }
                gmem_s.store(frag_p, mask);
+                // wangaq debug
+                // printf("begin:%d gmem s:\n", begin);
+                // __half * gmem = reinterpret_cast<__half*>(gmem_s.ptr_);
+                // for (int i = 0; i < Gmem_tile_s::LOOP_STRIDE_BYTES / 2 / 8; ++i) {
+                //     printf("tid:%d row:%d ", threadIdx.x, i);
+                //     for (int j = 0; j < 8; ++j) {
+                //         printf("col:%d value:%d\t", j, gmem[i*8+j]);
+                //     }
+                // }
                gmem_s.move();
            } else {
                softmax.pack(frag_p);
...
@@ -407,8 +585,8 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
                    #pragma unroll
                    for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
                        //"Apply" the dropout.
-                        frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
+                        frag_p[ki][mi].reg(ii) = fmha::fmul(frag_p[ki][mi].reg(ii), params.scale_dropout);
-                        frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
+                        frag_p[ki][mi].reg(ii) = fmha::frelu(frag_p[ki][mi].reg(ii));
                    }
                }
            }
...
@@ -421,6 +599,34 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
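Editor's note: dropped elements were stored negated (the sign-bit encoding above), so "applying" dropout is just a scale by 1/(1-p) followed by a clamp at zero; `fmul`/`frelu` are the fp32 counterparts of `hmul2`/`hrelu2` because the mmac path keeps P fragments in fp32. A single-element sketch of the trick:

    // Scale kept values by 1/(1-p); dropped (negative) values clamp to 0.
    __device__ inline float apply_encoded_dropout(float p_encoded, float rp_dropout) {
        float v = p_encoded * rp_dropout;
        return v > 0.f ? v : 0.f;
    }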
            #pragma unroll
            for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
                fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
+                // wangaq debug
+                // __syncthreads();
+                // if (blockIdx.x == 0) {
+                //     for (int mi = 0; mi < Mma_tile_o::MMAS_M; ++mi) {
+                //         printf("frag_p tid:%d ki:%d mi:%d %6.4f %6.4f %6.4f %6.4f\n", tidx, ki, mi,
+                //                frag_p[ki][mi].template elt_as<float>(0),
+                //                frag_p[ki][mi].template elt_as<float>(1),
+                //                frag_p[ki][mi].template elt_as<float>(2),
+                //                frag_p[ki][mi].template elt_as<float>(3));
+                //     }
+                //     for (int ni = 0; ni < Mma_tile_o::MMAS_N; ++ni) {
+                //         printf("frag_v tid:%d ki:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", tidx, ki, ni,
+                //                frag_v[ki][ni].template elt_as<float>(0),
+                //                frag_v[ki][ni].template elt_as<float>(1),
+                //                frag_v[ki][ni].template elt_as<float>(2),
+                //                frag_v[ki][ni].template elt_as<float>(3));
+                //     }
+                //     for (int mi = 0; mi < Mma_tile_o::MMAS_M; ++mi) {
+                //         for (int ni = 0; ni < Mma_tile_o::MMAS_N; ++ni) {
+                //             printf("acc_o tid:%d ki:%d mi:%d ni:%d %6.4f %6.4f %6.4f %6.4f\n", tidx, ki, mi, ni,
+                //                    acc_o[mi][ni].template elt_as<float>(0),
+                //                    acc_o[mi][ni].template elt_as<float>(1),
+                //                    acc_o[mi][ni].template elt_as<float>(2),
+                //                    acc_o[mi][ni].template elt_as<float>(3));
+                //         }
+                //     }
+                // }
            }
            // Loop over MMAS_M.
...
@@ -429,6 +635,18 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
            // Swizzle the elements and do the final reduction.
            smem_o.store(acc_o, ii);
+            // wangaq debug
+            // if (blockIdx.x == 0 && tidx == 0) {
+            //     printf("ii:%d smem_o %d bytes smem:\n", ii, Smem_tile_o::BYTES_PER_TILE);
+            //     float * smem_o = reinterpret_cast<float*>(&smem_[Gemm1::SMEM_OFFSET_O]);
+            //     for (int row = 0; row < Smem_tile_o::BYTES_PER_TILE / 4 / 8; ++row) {
+            //         printf("row:%d ", row);
+            //         for (int col = 0; col < 8; ++col) {
+            //             printf("col:%d value:%6.4f\t", col, smem_o[row*8+col]);
+            //         }
+            //         printf("\n");
+            //     }
+            // }
            // Make sure the data is in shared memory.
            __syncthreads();
...
@@ -436,6 +654,18 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
            // Load from shared memory.
            uint4 out[Gmem_tile_o::STGS_PER_LOOP];
            smem_o.load(out);
+            // wangaq debug
+            // if (blockIdx.x == 0 && tidx == 0) {
+            //     printf("ii:%d smem_o %d bytes smem:\n", ii, Smem_tile_o::BYTES_PER_TILE);
+            //     float * smem_o = reinterpret_cast<float*>(&smem_[Gemm1::SMEM_OFFSET_O]);
+            //     for (int row = 0; row < Smem_tile_o::BYTES_PER_TILE / 4 / 8; ++row) {
+            //         printf("row:%d ", row);
+            //         for (int col = 0; col < 8; ++col) {
+            //             printf("col:%d value:%6.4f\t", col, smem_o[row*8+col]);
+            //         }
+            //         printf("\n");
+            //     }
+            // }
            // Make sure the data was read from shared memory.
            if( ii < Gmem_tile_o::LOOPS - 1 ) {
...
@@ -472,10 +702,14 @@ inline __device__ void device_1xN(const Params &params,
    const int tidx_global = blockIdx.x * gridDim.x + threadIdx.x;
    auto seeds = at::cuda::philox::unpack(params.philox_args);
    Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
+    // if (blockIdx.x == 0 && threadIdx.x == 0)
+    //     printf("num_full_heads:%d num_main_groups:%d main_group_size:%d main_steps:%d, rest_steps:%d", num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
    for( int it = 0; it < num_full_heads; it++ ) {
        const int bidx = it * gridDim.x + blockIdx.x;
        const int bidh = bidx % params.h;
        const int bidb = bidx / params.h;
+        // if (blockIdx.x == 0 && threadIdx.x == 0)
+        //     printf("%s:%d N:%d M:%d steps:%d\n", __FILE__, __LINE__, Kernel_traits::Cta_tile_p::N, Kernel_traits::Cta_tile_p::M, STEPS);
        fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, 0, STEPS, ph);
        __syncthreads();
    }
...
@@ -490,6 +724,8 @@ inline __device__ void device_1xN(const Params &params,
        const int bidh = (head_offset + bidx) % params.h;
        const int bidb = (head_offset + bidx) / params.h;
        const int offset = group * main_steps;
+        // if (blockIdx.x == 0 && threadIdx.x == 0)
+        //     printf("%s:%d N:%d M:%d steps:%d\n", __FILE__, __LINE__, Kernel_traits::Cta_tile_p::N, Kernel_traits::Cta_tile_p::M, STEPS);
        fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, offset, main_steps, ph);
    } else {
        if( rest_steps == 0 ) return;
...
@@ -501,6 +737,8 @@ inline __device__ void device_1xN(const Params &params,
    for( int it = head_offset + bidx; it < total_heads; it += rest_ctas ) {
        const int bidh = it % params.h;
        const int bidb = it / params.h;
+        // if (blockIdx.x == 0 && threadIdx.x == 0)
+        //     printf("%s:%d N:%d M:%d steps:%d\n", __FILE__, __LINE__, Kernel_traits::Cta_tile_p::N, Kernel_traits::Cta_tile_p::M, STEPS);
        fmha::device_1xN_<Kernel_traits, Is_training>(params, bidb, bidh, offset, rest_steps, ph);
        __syncthreads();
    }
...
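Editor's note: `device_1xN` is a persistent-CTA scheduler — a fixed grid sweeps all batch*heads work items in full waves, recovering (batch, head) from the flat index with `/` and `%` on `params.h`. A host-side sketch of that mapping (toy sizes assumed):

    #include <cstdio>

    int main() {
        const int ctas = 4, h = 3, b = 2;          // hypothetical grid and head counts
        const int heads_total = b * h;
        const int num_full_heads = heads_total / ctas;
        for (int blockIdx_x = 0; blockIdx_x < ctas; ++blockIdx_x) {
            for (int it = 0; it < num_full_heads; ++it) {
                const int bidx = it * ctas + blockIdx_x;
                printf("cta %d -> batch %d head %d\n", blockIdx_x, bidx / h, bidx % h);
            }
        }
        return 0;
    }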
apex/contrib/csrc/fmha/src/fmha_kernel.h
View file @ 7eed2594
...
@@ -131,6 +131,8 @@ std::tuple<int , int, int, int, int, int> work_dist(const int total_ctas, const
    const int num_full_heads = heads_total / total_ctas;
    const int heads_last_wave = heads_total % total_ctas;
+    // printf("total_ctas:%d heads_total:%d num_full_heads:%d heads_last_wave:%d N:%d M:%d steps:%d\n",
+    //        total_ctas, heads_total, num_full_heads, heads_last_wave, Kernel_traits::Cta_tile_p::N, Kernel_traits::Cta_tile_p::M, STEPS_PER_HEAD);
    int num_main_groups = 0;
    int main_steps = 0;
...
setup.py
View file @ 7eed2594
...
@@ -512,15 +512,15 @@ if "--fmha" in sys.argv:
        CUDAExtension(name='fmhalib',
                      sources=[
                          'apex/contrib/csrc/fmha/fmha_api.cpp',
-                          'apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu',
+                          # 'apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu',
                          'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu',
                          'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu',
                          'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu',
-                          'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu',
+                          # 'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu',
-                          'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu',
+                          # 'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu',
-                          'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu',
+                          # 'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu',
-                          'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu',
+                          # 'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu',
-                          'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu',
+                          # 'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu',
                      ],
                      extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                          'nvcc': nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha},
...