gaoqiong / composable_kernel_ROCM / Commits

Commit 8a6e65a3, authored Feb 28, 2024 by aska-0096
update self-attention and cross-attention
parent b62926dc

Showing 4 changed files with 318 additions and 179 deletions:

  example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp   +22   -0
  example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc            +112  -72
  example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc             +100  -67
  example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp    +84   -40
example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp

@@ -301,6 +301,28 @@ using DeviceMHAFactory =
         S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
         // CShuffleBlockTransfer MN
         1, 1, S<1, 128, 1, 2>, 8, MaskingSpec>,
+    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
+        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
+        ADataType, B0DataType, B1DataType, CDataType,
+        Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
+        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
+        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
+        256,
+        // Gemm 0
+        128, 64, 48, 8, 4,
+        // Gemm 1
+        48, 64, 8,
+        16, 16, 16,
+        // Per repeat = wave_m = wave_num, wave_n = 1
+        1, 4, 3,
+        // ABlockTransfer MK -> K0 M K1
+        S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        // B0BlockTransfer LK -> K0 L K1
+        S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true,
+        // B1BlockTransfer NL -> L0 N L1
+        S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
+        // CShuffleBlockTransfer MN
+        1, 1, S<1, 128, 1, 2>, 8,
+        MaskingSpec>
 #endif
     >;
...
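A note on the template lists above and in the self-attention file below: `S<...>` appears to be the usual composable_kernel shorthand for `ck::Sequence<...>`, a compile-time integer list that spells out thread-cluster shapes and dimension orders for the block transfers. A minimal sketch of that alias (my own illustration, not part of the commit; the header paths and the alias definition are assumptions based on how the CK examples are normally written):

#include "ck/ck.hpp"               // assumption: provides ck::index_t
#include "ck/utility/sequence.hpp" // assumption: provides ck::Sequence

// The shorthand the factory entries rely on; S<2, 128, 1> etc. are
// compile-time Sequence values, not runtime data.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// Example (hypothetical name): thread-cluster lengths for an "ABlockTransfer MK -> K0 M K1" stage.
using ABlockTransferThreadClusterLengths = S<2, 128, 1>;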
example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc

@@ -9,20 +9,18 @@ int run(int argc, char* argv[])
     // GEMM shape for A/B0/B1/C
     // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
-    ck::index_t M = 256;
-    ck::index_t N = 64;
-    ck::index_t K = 80;
-    ck::index_t O = 80;
-    // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
-    // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
-    // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
-    ck::index_t G0 = 2;
-    ck::index_t G1 = 8;
+    ck::index_t q_sequence_length  = 256;
+    ck::index_t kv_sequence_length = 64;
+    ck::index_t head_dim           = 80;
+    // Output shape C[batch_size, q_sequence_length, head_num, head_dim]. Batch dim, outer dim,
+    // inner dim must match GEMM shape C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) C_g0_m_g1_o =
+    // permute(C_g0_g1_m_o, [0, 2, 1, 3])
+    ck::index_t batch_size = 2;
+    ck::index_t head_num   = 8;
     float alpha = 1;
-    bool input_permute = true;
+    bool input_permute  = false;
     bool output_permute = true;

     if(argc == 1)
...
@@ -35,58 +33,85 @@ int run(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
-    else if(argc == 13)
+    else if(argc == 10)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
         time_kernel     = std::stoi(argv[3]);
-        M  = std::stoi(argv[4]);
-        N  = std::stoi(argv[5]);
-        K  = std::stoi(argv[6]);
-        O  = std::stoi(argv[7]);
-        G0 = std::stoi(argv[8]);
-        G1 = std::stoi(argv[9]);
-        alpha = std::stof(argv[10]);
-        input_permute  = std::stoi(argv[11]);
-        output_permute = std::stoi(argv[12]);
+        q_sequence_length  = std::stoi(argv[4]);
+        kv_sequence_length = std::stoi(argv[5]);
+        head_dim           = std::stoi(argv[6]);
+        batch_size         = std::stoi(argv[7]);
+        head_num           = std::stoi(argv[8]);
+        alpha              = std::stof(argv[9]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 11: M, N, K, O, G0, G1\n");
-        printf("arg10: scale (alpha)\n");
-        printf("arg11 to 12: input / output permute\n");
+        printf("arg4 to 8: q_sequence_length, kv_sequence_length, head_dim, batch_size, head_num\n");
+        printf("arg9: scale (alpha)\n");
         exit(0);
     }

-    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
+    std::vector<ck::index_t> a_gs_ms_ks_lengths{batch_size, head_num, q_sequence_length, head_dim};
     std::vector<ck::index_t> a_gs_ms_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
-            : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
+            ? std::vector<ck::index_t>{q_sequence_length * head_num * head_dim,
+                                       head_dim,
+                                       head_num * head_dim,
+                                       1} // A layout [batch_size, q_sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * q_sequence_length * head_dim,
+                                       q_sequence_length * head_dim,
+                                       head_dim,
+                                       1}; // A layout [batch_size, head_num, q_sequence_length, head_dim]
-    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
+    std::vector<ck::index_t> b0_gs_ns_ks_lengths{batch_size, head_num, kv_sequence_length, head_dim};
     std::vector<ck::index_t> b0_gs_ns_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
-            : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
+            ? std::vector<ck::index_t>{kv_sequence_length * head_num * head_dim,
+                                       head_dim,
+                                       head_num * head_dim,
+                                       1} // B0 layout [batch_size, kv_sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * kv_sequence_length * head_dim,
+                                       kv_sequence_length * head_dim,
+                                       head_dim,
+                                       1}; // B0 layout [batch_size, head_num, kv_sequence_length, head_dim]
-    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
+    std::vector<ck::index_t> b1_gs_os_ns_lengths{batch_size, head_num, head_dim, kv_sequence_length};
     std::vector<ck::index_t> b1_gs_os_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
-            : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
+            ? std::vector<ck::index_t>{kv_sequence_length * head_num * head_dim,
+                                       head_dim,
+                                       1,
+                                       head_num * head_dim} // B1 layout [batch_size, kv_sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * kv_sequence_length * head_dim,
+                                       kv_sequence_length * head_dim,
+                                       1,
+                                       head_dim}; // B1 layout [batch_size, head_num, kv_sequence_length, head_dim]
-    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
+    std::vector<ck::index_t> c_gs_ms_os_lengths{batch_size, head_num, q_sequence_length, head_dim};
     std::vector<ck::index_t> c_gs_ms_os_strides =
         output_permute
-            ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
-            : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
+            ? std::vector<ck::index_t>{q_sequence_length * head_num * head_dim,
+                                       head_dim,
+                                       head_num * head_dim,
+                                       1} // C layout [batch_size, q_sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * q_sequence_length * head_dim,
+                                       q_sequence_length * head_dim,
+                                       head_dim,
+                                       1}; // C layout [batch_size, head_num, q_sequence_length, head_dim]

     Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
     Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
...
@@ -158,9 +183,14 @@ int run(int argc, char* argv[])
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }
-    std::vector<ck::index_t> kv_gs_ns_ks_lengths{G0, G1, N, 2, K};
-    std::vector<ck::index_t> kv_gs_ns_ks_strides = std::vector<ck::index_t>{
-        N * G1 * 2 * K, 2 * K, G1 * 2 * K, K, 1}; // kv layout [G0, M, G1, 2, K]
+    std::vector<ck::index_t> kv_gs_ns_ks_lengths{
+        batch_size, head_num, kv_sequence_length, 2, head_dim};
+    std::vector<ck::index_t> kv_gs_ns_ks_strides = std::vector<ck::index_t>{
+        kv_sequence_length * head_num * 2 * head_dim,
+        2 * head_dim,
+        head_num * 2 * head_dim,
+        head_dim,
+        1}; // kv layout [batch_size, q_sequence_length, head_num, 2, head_dim]
     Tensor<ADataType> kv_gs_ns_ks(kv_gs_ns_ks_lengths, kv_gs_ns_ks_strides);
     // merge kv into a packed pointer send to device
     b0_gs_ns_ks.ForEach(
...
@@ -189,20 +219,20 @@ int run(int argc, char* argv[])
     printf("Verification: %s\n", do_verification ? "ON" : "OFF");
     // TODO ANT: replace array with vector?
     ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
-        const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{});
-        using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>;
+        const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
+        using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
         auto gemm    = DeviceMHAInstance{};
         auto invoker = gemm.MakeCrossAttnInvoker();
         auto argument =
             gemm.MakeCrossAttnArgument(static_cast<ADataType*>(q_device_buf.GetDeviceBuffer()),
                                        static_cast<B0DataType*>(kv_device_buf.GetDeviceBuffer()),
                                        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-                                       G0,
-                                       M,
-                                       N,
-                                       G1,
-                                       K,
+                                       batch_size,
+                                       q_sequence_length,
+                                       kv_sequence_length,
+                                       head_num,
+                                       head_dim,
                                        alpha);
         // if(!gemm.IsSupportedArgument(argument))
...
@@ -212,13 +242,17 @@ int run(int argc, char* argv[])
         //     return 0;
         // }
-        ck::index_t BatchCount = G0 * G1;
+        ck::index_t BatchCount = batch_size * head_num;
         float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
-        std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
-        std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
-                                 sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * BatchCount;
+        std::size_t flop = (size_t(q_sequence_length) * kv_sequence_length * head_dim * 2 +
+                            size_t(q_sequence_length) * kv_sequence_length * head_dim * 2) *
+                           BatchCount;
+        std::size_t num_btype =
+            (sizeof(ADataType) * q_sequence_length * head_dim +
+             sizeof(B0DataType) * head_dim * kv_sequence_length +
+             sizeof(B1DataType) * kv_sequence_length * head_dim +
+             sizeof(CDataType) * q_sequence_length * head_dim) *
+            BatchCount;
         float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...
@@ -237,22 +271,26 @@ int run(int argc, char* argv[])
         {
             c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
-            Tensor<ADataType> a_g_m_k({BatchCount, M, K});
-            Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
-            Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
-            Tensor<Acc0DataType> acc0_g_m_n({BatchCount, M, N});       // scratch object after gemm0
-            Tensor<ADataType> a1_g_m_n({BatchCount, M, N});            // scratch object after softmax
-            Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
+            Tensor<ADataType> a_g_m_k({BatchCount, q_sequence_length, head_dim});
+            Tensor<B0DataType> b0_g_k_n({BatchCount, head_dim, kv_sequence_length});
+            Tensor<B1DataType> b1_g_n_o({BatchCount, kv_sequence_length, head_dim});
+            Tensor<Acc0DataType> acc0_g_m_n(
+                {BatchCount, q_sequence_length, kv_sequence_length}); // scratch object after gemm0
+            Tensor<ADataType> a1_g_m_n(
+                {BatchCount, q_sequence_length, kv_sequence_length}); // scratch object after softmax
+            Tensor<CDataType> c_g_m_o_host_result(
+                {BatchCount, q_sequence_length, head_dim}); // scratch object after gemm1
             // permute
             a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
-                a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+                a_g_m_k(idx[0] * head_num + idx[1], idx[2], idx[3]) = self(idx);
             });
             b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
-                b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+                b0_g_k_n(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx);
             });
             b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
-                b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+                b1_g_n_o(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx);
             });
             // gemm 0
...
@@ -264,7 +302,7 @@ int run(int argc, char* argv[])
             ref_gemm0_invoker.Run(ref_gemm0_argument);
             // masking
-            const auto mask = typename DeviceMHAInstance::C0MatrixMask(N);
+            const auto mask = typename DeviceMHAInstance::C0MatrixMask(kv_sequence_length);
             acc0_g_m_n.ForEach([&](auto& self, auto idx) {
                 if(mask.IsMaskedElement(idx[1], idx[2]))
                     self(idx) = -ck::NumericLimits<float>::Infinity();
...
@@ -294,7 +332,7 @@ int run(int argc, char* argv[])
                 const size_t& g0 = idx[0];
                 const size_t& g1 = idx[1];
-                const size_t g = g0 * G1 + g1;
+                const size_t g = g0 * head_num + g1;
                 self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
             });
...
@@ -330,8 +368,10 @@ int run(int argc, char* argv[])
        std::cout << "---------------------------------------------------------------------------------"
                     "-----------"
                  << std::endl;
-       std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M
-                 << ", N: " << N << ", K: " << K << ", O: " << O << std::endl;
+       std::cout << "Problem Size: BatchCount: " << batch_size << ", HeadNum: " << head_num
+                 << ", q_sequence_length: " << q_sequence_length
+                 << ", kv_sequence_length: " << kv_sequence_length
+                 << ", head_dim: " << head_dim << std::endl;
        std::cout << "---------------------------------------------------------------------------------"
                     "-----------"
                  << std::endl;
...
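To make the stride arithmetic in the hunk above easier to check, here is a small self-contained sketch (my own illustration, not part of the commit; the helper name make_bshd_strides and the use of std::vector are assumptions). It computes the row-major strides of the permuted [batch_size, q_sequence_length, head_num, head_dim] layout used when input_permute / output_permute are enabled, for the example's default problem size:

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: strides for a tensor whose logical index order is
// (batch, head, sequence, head_dim) but whose memory order is
// [batch, sequence, head, head_dim] (the "permuted" layout in the example).
std::vector<std::int64_t> make_bshd_strides(std::int64_t sequence_length,
                                            std::int64_t head_num,
                                            std::int64_t head_dim)
{
    return {sequence_length * head_num * head_dim, // step between batches
            head_dim,                              // step between heads
            head_num * head_dim,                   // step between sequence positions
            1};                                    // step within head_dim
}

int main()
{
    // Default cross-attention sizes from the example above.
    auto a_strides = make_bshd_strides(/*q_sequence_length=*/256, /*head_num=*/8, /*head_dim=*/80);
    for(auto s : a_strides)
        std::cout << s << " "; // prints: 163840 80 640 1
    std::cout << "\n";
}

The same formula, with kv_sequence_length in place of q_sequence_length, reproduces the B0 strides in the diff.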
example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc

@@ -9,20 +9,17 @@ int run(int argc, char* argv[])
     // GEMM shape for A/B0/B1/C
     // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
-    ck::index_t M = 256;
-    ck::index_t N = 256;
-    ck::index_t K = 80;
-    ck::index_t O = 80;
-    // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
-    // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
-    // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
-    ck::index_t G0 = 2;
-    ck::index_t G1 = 8;
+    ck::index_t sequence_length = 256;
+    ck::index_t head_dim        = 80;
+    // Output shape C[batch_size, sequence_length, head_num, head_dim]. Batch dim, outer dim, inner
+    // dim must match GEMM shape C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) C_g0_m_g1_o =
+    // permute(C_g0_g1_m_o, [0, 2, 1, 3])
+    ck::index_t batch_size = 2;
+    ck::index_t head_num   = 8;
     float alpha = 1;
-    bool input_permute = true;
+    bool input_permute  = false;
     bool output_permute = true;

     if(argc == 1)
...
@@ -35,58 +32,81 @@ int run(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
-    else if(argc == 13)
+    else if(argc == 9)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
         time_kernel     = std::stoi(argv[3]);
-        M  = std::stoi(argv[4]);
-        N  = std::stoi(argv[5]);
-        K  = std::stoi(argv[6]);
-        O  = std::stoi(argv[7]);
-        G0 = std::stoi(argv[8]);
-        G1 = std::stoi(argv[9]);
-        alpha = std::stof(argv[10]);
-        input_permute  = std::stoi(argv[11]);
-        output_permute = std::stoi(argv[12]);
+        sequence_length = std::stoi(argv[4]);
+        head_dim        = std::stoi(argv[5]);
+        batch_size      = std::stoi(argv[6]);
+        head_num        = std::stoi(argv[7]);
+        alpha           = std::stof(argv[8]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 11: M, N, K, O, G0, G1\n");
-        printf("arg10: scale (alpha)\n");
-        printf("arg11 to 12: input / output permute\n");
+        printf("arg4 to 7: sequence_length, head_dim, batch_size, head_num\n");
+        printf("arg8: scale (alpha)\n");
         exit(0);
     }

-    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
+    std::vector<ck::index_t> a_gs_ms_ks_lengths{batch_size, head_num, sequence_length, head_dim};
     std::vector<ck::index_t> a_gs_ms_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
-            : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
+            ? std::vector<ck::index_t>{sequence_length * head_num * head_dim,
+                                       head_dim,
+                                       head_num * head_dim,
+                                       1} // A layout [batch_size, sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * sequence_length * head_dim,
+                                       sequence_length * head_dim,
+                                       head_dim,
+                                       1}; // A layout [batch_size, head_num, sequence_length, head_dim]
-    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
+    std::vector<ck::index_t> b0_gs_ns_ks_lengths{batch_size, head_num, sequence_length, head_dim};
     std::vector<ck::index_t> b0_gs_ns_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
-            : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
+            ? std::vector<ck::index_t>{sequence_length * head_num * head_dim,
+                                       head_dim,
+                                       head_num * head_dim,
+                                       1} // B0 layout [batch_size, sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * sequence_length * head_dim,
+                                       sequence_length * head_dim,
+                                       head_dim,
+                                       1}; // B0 layout [batch_size, head_num, sequence_length, head_dim]
-    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
+    std::vector<ck::index_t> b1_gs_os_ns_lengths{batch_size, head_num, head_dim, sequence_length};
     std::vector<ck::index_t> b1_gs_os_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
-            : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
+            ? std::vector<ck::index_t>{sequence_length * head_num * head_dim,
+                                       head_dim,
+                                       1,
+                                       head_num * head_dim} // B1 layout [batch_size, sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * sequence_length * head_dim,
+                                       sequence_length * head_dim,
+                                       1,
+                                       head_dim}; // B1 layout [batch_size, head_num, sequence_length, head_dim]
-    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
+    std::vector<ck::index_t> c_gs_ms_os_lengths{batch_size, head_num, sequence_length, head_dim};
     std::vector<ck::index_t> c_gs_ms_os_strides =
         output_permute
-            ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
-            : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
+            ? std::vector<ck::index_t>{sequence_length * head_num * head_dim,
+                                       head_dim,
+                                       head_num * head_dim,
+                                       1} // C layout [batch_size, sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * sequence_length * head_dim,
+                                       sequence_length * head_dim,
+                                       head_dim,
+                                       1}; // C layout [batch_size, head_num, sequence_length, head_dim]

     Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
     Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
...
@@ -158,9 +178,14 @@ int run(int argc, char* argv[])
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }
-    std::vector<ck::index_t> qkv_gs_ms_ks_lengths{G0, G1, M, 3, K};
-    std::vector<ck::index_t> qkv_gs_ms_ks_strides = std::vector<ck::index_t>{
-        M * G1 * 3 * K, 3 * K, G1 * 3 * K, K, 1}; // qkv layout [G0, M, G1, 3, K]
+    std::vector<ck::index_t> qkv_gs_ms_ks_lengths{
+        batch_size, head_num, sequence_length, 3, head_dim};
+    std::vector<ck::index_t> qkv_gs_ms_ks_strides = std::vector<ck::index_t>{
+        sequence_length * head_num * 3 * head_dim,
+        3 * head_dim,
+        head_num * 3 * head_dim,
+        head_dim,
+        1}; // qkv layout [batch_size, sequence_length, head_num, 3, head_dim]
     Tensor<ADataType> qkv_gs_ms_ks(qkv_gs_ms_ks_lengths, qkv_gs_ms_ks_strides);
     // merge qkv into a packed pointer send to device
     a_gs_ms_ks.ForEach(
...
@@ -198,10 +223,10 @@ int run(int argc, char* argv[])
         auto argument =
             gemm.MakeSelfAttnArgument(static_cast<ADataType*>(qkv_device_buf.GetDeviceBuffer()),
                                       static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-                                      G0,
-                                      M,
-                                      G1,
-                                      K,
+                                      batch_size,
+                                      sequence_length,
+                                      head_num,
+                                      head_dim,
                                       alpha);
         // if(!gemm.IsSupportedArgument(argument))
...
@@ -211,13 +236,17 @@ int run(int argc, char* argv[])
         //     return 0;
         // }
-        ck::index_t BatchCount = G0 * G1;
+        ck::index_t BatchCount = batch_size * head_num;
         float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
-        std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
-        std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
-                                 sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * BatchCount;
+        std::size_t flop = (size_t(sequence_length) * sequence_length * head_dim * 2 +
+                            size_t(sequence_length) * sequence_length * head_dim * 2) *
+                           BatchCount;
+        std::size_t num_btype =
+            (sizeof(ADataType) * sequence_length * head_dim +
+             sizeof(B0DataType) * head_dim * sequence_length +
+             sizeof(B1DataType) * sequence_length * head_dim +
+             sizeof(CDataType) * sequence_length * head_dim) *
+            BatchCount;
         float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...
@@ -236,22 +265,25 @@ int run(int argc, char* argv[])
         {
             c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
-            Tensor<ADataType> a_g_m_k({BatchCount, M, K});
-            Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
-            Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
-            Tensor<Acc0DataType> acc0_g_m_n({BatchCount, M, N});       // scratch object after gemm0
-            Tensor<ADataType> a1_g_m_n({BatchCount, M, N});            // scratch object after softmax
-            Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
+            Tensor<ADataType> a_g_m_k({BatchCount, sequence_length, head_dim});
+            Tensor<B0DataType> b0_g_k_n({BatchCount, head_dim, sequence_length});
+            Tensor<B1DataType> b1_g_n_o({BatchCount, sequence_length, head_dim});
+            Tensor<Acc0DataType> acc0_g_m_n(
+                {BatchCount, sequence_length, sequence_length}); // scratch object after gemm0
+            Tensor<ADataType> a1_g_m_n(
+                {BatchCount, sequence_length, sequence_length}); // scratch object after softmax
+            Tensor<CDataType> c_g_m_o_host_result(
+                {BatchCount, sequence_length, head_dim}); // scratch object after gemm1
             // permute
             a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
-                a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+                a_g_m_k(idx[0] * head_num + idx[1], idx[2], idx[3]) = self(idx);
            });
             b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
-                b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+                b0_g_k_n(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx);
             });
             b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
-                b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+                b1_g_n_o(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx);
             });
             // gemm 0
...
@@ -263,7 +295,7 @@ int run(int argc, char* argv[])
             ref_gemm0_invoker.Run(ref_gemm0_argument);
             // masking
-            const auto mask = typename DeviceMHAInstance::C0MatrixMask(N);
+            const auto mask = typename DeviceMHAInstance::C0MatrixMask(sequence_length);
             acc0_g_m_n.ForEach([&](auto& self, auto idx) {
                 if(mask.IsMaskedElement(idx[1], idx[2]))
                     self(idx) = -ck::NumericLimits<float>::Infinity();
...
@@ -293,7 +325,7 @@ int run(int argc, char* argv[])
                 const size_t& g0 = idx[0];
                 const size_t& g1 = idx[1];
-                const size_t g = g0 * G1 + g1;
+                const size_t g = g0 * head_num + g1;
                 self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
             });
...
@@ -329,8 +361,9 @@ int run(int argc, char* argv[])
        std::cout << "---------------------------------------------------------------------------------"
                     "-----------"
                  << std::endl;
-       std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M
-                 << ", N: " << N << ", K: " << K << ", O: " << O << std::endl;
+       std::cout << "Problem Size: BatchCount: " << batch_size << ", HeadNum: " << head_num
+                 << ", sequence_length: " << sequence_length << ", head_dim: " << head_dim
+                 << std::endl;
        std::cout << "---------------------------------------------------------------------------------"
                     "-----------"
                  << std::endl;
...
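The packed qkv buffer above interleaves Q, K and V along a size-3 axis. As a sanity check on that layout, here is a sketch of the flat-offset arithmetic implied by the strides in the diff (my own illustration, not part of the commit; the function name qkv_flat_offset is hypothetical):

#include <cstdint>
#include <iostream>

// Hypothetical helper: flat element offset into a packed qkv buffer stored
// row-major as [batch, sequence, head, 3, head_dim], where which = 0 selects Q,
// 1 selects K and 2 selects V. The per-dimension strides match those the
// example assigns to qkv_gs_ms_ks_strides.
std::int64_t qkv_flat_offset(std::int64_t b, std::int64_t s, std::int64_t h,
                             std::int64_t which, std::int64_t d,
                             std::int64_t sequence_length, std::int64_t head_num,
                             std::int64_t head_dim)
{
    const std::int64_t stride_b = sequence_length * head_num * 3 * head_dim;
    const std::int64_t stride_s = head_num * 3 * head_dim;
    const std::int64_t stride_h = 3 * head_dim;
    const std::int64_t stride_w = head_dim;
    return b * stride_b + s * stride_s + h * stride_h + which * stride_w + d;
}

int main()
{
    // First element of V for batch 0, sequence position 1, head 2, with the
    // default sizes above (sequence_length = 256, head_num = 8, head_dim = 80).
    std::cout << qkv_flat_offset(0, 1, 2, /*V=*/2, 0, 256, 8, 80) << "\n"; // prints 2560
}

The cross-attention kv buffer follows the same pattern with a size-2 axis instead of 3.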
example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp

@@ -83,12 +83,34 @@ using DeviceMHAFactory =
         GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
         32,
         // Gemm 0
-        16, 128, 64, 8, 8,
+        16, 32, 160, 8, 8,
         // Gemm 1
-        64, 64, 8,
+        80, 32, 8,
         16, 16, 16,
+        // Per repeat = wave_m = wave_num, wave_n = 1
+        1, 2, 5,
+        // ABlockTransfer MK -> K0 M K1
+        S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        // B0BlockTransfer LK -> K0 L K1
+        S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        // B1BlockTransfer NL -> L0 N L1
+        S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
+        // CShuffleBlockTransfer MN
+        1, 1, S<1, 16, 1, 2>, 8,
+        MaskingSpec>,
+    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
+        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
+        ADataType, B0DataType, B1DataType, CDataType,
+        Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
+        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
+        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
+        32,
+        // Gemm 0
+        16, 64, 80, 8, 8,
+        // Gemm 1
+        80, 64, 8,
+        16, 16, 16,
         // Per repeat = wave_m = wave_num, wave_n = 1
-        1, 8, 4,
+        1, 4, 5,
         // ABlockTransfer MK -> K0 M K1
         S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B0BlockTransfer LK -> K0 L K1
...
@@ -105,12 +127,12 @@ using DeviceMHAFactory =
         GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
         32,
         // Gemm 0
-        16, 64, 64, 8, 8,
+        16, 64, 48, 8, 8,
         // Gemm 1
-        64, 64, 8,
+        48, 64, 8,
         16, 16, 16,
         // Per repeat = wave_m = wave_num, wave_n = 1
-        1, 4, 4,
+        1, 4, 3,
         // ABlockTransfer MK -> K0 M K1
         S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B0BlockTransfer LK -> K0 L K1
...
@@ -129,16 +151,16 @@ using DeviceMHAFactory =
         GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
         64,
         // Gemm 0
-        32, 128, 64, 8, 8,
+        32, 64, 48, 8, 8,
         // Gemm 1
-        64, 64, 8,
+        48, 64, 8,
         16, 16, 16,
         // Per repeat = wave_m = wave_num, wave_n = 1
-        1, 8, 4,
+        1, 4, 3,
         // ABlockTransfer MK -> K0 M K1
         S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B0BlockTransfer LK -> K0 L K1
-        S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B1BlockTransfer NL -> L0 N L1
         S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
         // CShuffleBlockTransfer MN
...
@@ -151,16 +173,38 @@ using DeviceMHAFactory =
         GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
         64,
         // Gemm 0
-        32, 64, 64, 8, 8,
+        32, 64, 80, 8, 8,
         // Gemm 1
-        64, 64, 8,
+        80, 64, 8,
         16, 16, 16,
         // Per repeat = wave_m = wave_num, wave_n = 1
-        1, 4, 4,
+        1, 4, 5,
         // ABlockTransfer MK -> K0 M K1
         S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B0BlockTransfer LK -> K0 L K1
-        S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        // B1BlockTransfer NL -> L0 N L1
+        S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
+        // CShuffleBlockTransfer MN
+        1, 1, S<1, 32, 1, 2>, 8,
+        MaskingSpec>,
+    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
+        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
+        ADataType, B0DataType, B1DataType, CDataType,
+        Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
+        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
+        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
+        64,
+        // Gemm 0
+        32, 32, 160, 8, 8,
+        // Gemm 1
+        80, 32, 8,
+        16, 16, 16,
+        // Per repeat = wave_m = wave_num, wave_n = 1
+        1, 2, 5,
+        // ABlockTransfer MK -> K0 M K1
+        S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        // B0BlockTransfer LK -> K0 L K1
+        S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B1BlockTransfer NL -> L0 N L1
         S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
         // CShuffleBlockTransfer MN
...
@@ -175,20 +219,20 @@ using DeviceMHAFactory =
         GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
         128,
         // Gemm 0
-        64, 128, 64, 8, 8,
+        64, 128, 80, 8, 8,
         // Gemm 1
-        64, 64, 8,
+        80, 64, 8,
         16, 16, 16,
         // Per repeat = wave_m = wave_num, wave_n = 1
-        1, 8, 4,
+        1, 8, 5,
         // ABlockTransfer MK -> K0 M K1
         S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B0BlockTransfer LK -> K0 L K1
-        S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B1BlockTransfer NL -> L0 N L1
         S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
         // CShuffleBlockTransfer MN
         1, 1, S<1, 64, 1, 2>, 8,
         MaskingSpec>,
     ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
         NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
...
@@ -197,45 +241,45 @@ using DeviceMHAFactory =
         GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
         128,
         // Gemm 0
-        64, 64, 64, 8, 8,
+        64, 192, 48, 8, 8,
         // Gemm 1
-        64, 64, 8,
+        48, 64, 8,
         16, 16, 16,
         // Per repeat = wave_m = wave_num, wave_n = 1
-        1, 4, 4,
+        1, 12, 3,
         // ABlockTransfer MK -> K0 M K1
         S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B0BlockTransfer LK -> K0 L K1
-        S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B1BlockTransfer NL -> L0 N L1
         S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
         // CShuffleBlockTransfer MN
         1, 1, S<1, 64, 1, 2>, 8,
         MaskingSpec>,
-#endif
+#ifdef CK_MHA_USE_WAVE_8
     ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
         NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
         ADataType, B0DataType, B1DataType, CDataType,
         Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
         AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
         GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
-        256,
+        128,
         // Gemm 0
-        128, 128, 64, 8, 8,
+        64, 64, 48, 8, 8,
         // Gemm 1
-        64, 64, 8,
+        48, 64, 8,
         16, 16, 16,
         // Per repeat = wave_m = wave_num, wave_n = 1
-        1, 8, 4,
+        1, 4, 3,
         // ABlockTransfer MK -> K0 M K1
-        S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B0BlockTransfer LK -> K0 L K1
-        S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B1BlockTransfer NL -> L0 N L1
-        S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
+        S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
         // CShuffleBlockTransfer MN
-        1, 1, S<1, 128, 1, 2>, 8,
+        1, 1, S<1, 64, 1, 2>, 8,
         MaskingSpec>,
-#endif
+#ifdef CK_MHA_USE_WAVE_8
     ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
         NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
         ADataType, B0DataType, B1DataType, CDataType,
         Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
...
@@ -243,18 +287,18 @@ using DeviceMHAFactory =
         GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
         256,
         // Gemm 0
-        128, 128, 64, 8, 8,
+        128, 192, 48, 8, 4,
         // Gemm 1
-        64, 64, 8,
+        48, 64, 8,
         16, 16, 16,
         // Per repeat = wave_m = wave_num, wave_n = 1
-        1, 8, 4,
+        1, 12, 3,
         // ABlockTransfer MK -> K0 M K1
         S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B0BlockTransfer LK -> K0 L K1
-        S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true,
         // B1BlockTransfer NL -> L0 N L1
         S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
         // CShuffleBlockTransfer MN
         1, 1, S<1, 128, 1, 2>, 8,
         MaskingSpec>
...
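The run_*.inc files above derive their reported numbers from two formulas: the FLOP count of the two GEMMs and the bytes moved for A, B0, B1 and C. A standalone sketch of that arithmetic follows (my own illustration, not part of the commit; the function names are hypothetical, and the 2-byte element size assumes the fp16 types used by these examples):

#include <cstddef>
#include <iostream>

// FLOP count of gemm0 (Q x K^T) plus gemm1 (softmax output x V), following the
// formula the example assigns to `flop`.
std::size_t attention_flop(std::size_t q_len, std::size_t kv_len, std::size_t head_dim,
                           std::size_t batch_count)
{
    return (q_len * kv_len * head_dim * 2 + q_len * kv_len * head_dim * 2) * batch_count;
}

// Bytes moved for A, B0, B1 and C, following the formula the example assigns
// to `num_btype` (elem_bytes = 2 for fp16 operands).
std::size_t attention_bytes(std::size_t q_len, std::size_t kv_len, std::size_t head_dim,
                            std::size_t batch_count, std::size_t elem_bytes = 2)
{
    return (elem_bytes * q_len * head_dim + elem_bytes * head_dim * kv_len +
            elem_bytes * kv_len * head_dim + elem_bytes * q_len * head_dim) *
           batch_count;
}

int main()
{
    // Cross-attention defaults: q = 256, kv = 64, head_dim = 80, BatchCount = 2 * 8 = 16.
    const std::size_t flop  = attention_flop(256, 64, 80, 16);
    const std::size_t bytes = attention_bytes(256, 64, 80, 16);
    // The example reports static_cast<float>(flop) / 1.E9 / ave_time (ave_time in ms) as TFLOPS.
    std::cout << flop << " FLOP, " << bytes << " bytes\n";
}

For the self-attention example, q_len and kv_len are both sequence_length, which reproduces its version of the same formulas.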