Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e87ddb0e
Commit
e87ddb0e
authored
Oct 26, 2023
by
letaoqin
Browse files
Merge branch 'mha-train-develop' into mha-train-develop-bias-shfl
parents
13129772
5ff2d646
Changes
38
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1537 additions
and
918 deletions
+1537
-918
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp
..._softmax_gemm/batched_multihead_attention_backward_v2.cpp
+113
-71
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp
..._softmax_gemm/batched_multihead_attention_backward_v3.cpp
+109
-67
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_v2.cpp
...e_softmax_gemm/batched_multihead_attention_forward_v2.cpp
+7
-11
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp
...ale_softmax_gemm/batched_multihead_attention_train_v2.cpp
+121
-85
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp
..._softmax_gemm/grouped_multihead_attention_backward_v2.cpp
+108
-70
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp
..._softmax_gemm/grouped_multihead_attention_backward_v3.cpp
+109
-71
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v2.cpp
...e_softmax_gemm/grouped_multihead_attention_forward_v2.cpp
+7
-11
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
...ale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
+127
-89
example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc
..._softmax_gemm/run_batched_multihead_attention_forward.inc
+60
-50
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
..._softmax_gemm/run_grouped_multihead_attention_forward.inc
+72
-57
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp
...ten_bias/batched_multihead_attention_bias_backward_v2.cpp
+115
-75
example/52_flash_atten_bias/batched_multihead_attention_bias_forward_v2.cpp
...tten_bias/batched_multihead_attention_bias_forward_v2.cpp
+7
-11
example/52_flash_atten_bias/batched_multihead_attention_bias_forward_v2_zcheck.cpp
...as/batched_multihead_attention_bias_forward_v2_zcheck.cpp
+7
-11
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp
...ten_bias/grouped_multihead_attention_bias_backward_v2.cpp
+116
-78
example/52_flash_atten_bias/grouped_multihead_attention_bias_forward_v2.cpp
...tten_bias/grouped_multihead_attention_bias_forward_v2.cpp
+7
-11
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward_v2.inc
..._bias/run_batched_multihead_attention_bias_forward_v2.inc
+65
-55
example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward_v2.inc
..._bias/run_grouped_multihead_attention_bias_forward_v2.inc
+77
-63
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+102
-10
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+101
-8
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+107
-14
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp
View file @
e87ddb0e
...
...
@@ -269,14 +269,15 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
M
=
512
;
ck
::
index_t
N
=
512
;
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
G0
=
4
;
ck
::
index_t
G1
=
6
;
// y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
// y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
ck
::
index_t
M
=
512
;
ck
::
index_t
N
=
512
;
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
G0
=
4
;
ck
::
index_t
G1Q
=
6
;
// h_q
ck
::
index_t
G1KV
=
6
;
// h_kv
bool
input_permute
=
false
;
bool
output_permute
=
false
;
...
...
@@ -295,32 +296,33 @@ int run(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
1
3
)
else
if
(
argc
==
1
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1
=
std
::
stoi
(
argv
[
9
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1Q
=
std
::
stoi
(
argv
[
9
]);
G1KV
=
std
::
stoi
(
argv
[
10
]);
p_drop
=
std
::
stof
(
argv
[
1
0
]);
p_drop
=
std
::
stof
(
argv
[
1
1
]);
input_permute
=
std
::
stoi
(
argv
[
1
1
]);
output_permute
=
std
::
stoi
(
argv
[
1
2
]);
input_permute
=
std
::
stoi
(
argv
[
1
2
]);
output_permute
=
std
::
stoi
(
argv
[
1
3
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 1
1
: M, N, K, O, G0, G1
\n
"
);
printf
(
"arg1
0
:
scale (alpha)
\n
"
);
printf
(
"arg1
1
to 1
2
: input / output permute
\n
"
);
printf
(
"arg4 to 1
0
: M, N, K, O, G0, G1
Q, G1KV
\n
"
);
printf
(
"arg1
1
:
p_drop
\n
"
);
printf
(
"arg1
2
to 1
3
: input / output permute
\n
"
);
exit
(
0
);
}
...
...
@@ -337,7 +339,8 @@ int run(int argc, char* argv[])
std
::
cout
<<
"K: "
<<
K
<<
std
::
endl
;
std
::
cout
<<
"O: "
<<
O
<<
std
::
endl
;
std
::
cout
<<
"G0: "
<<
G0
<<
std
::
endl
;
std
::
cout
<<
"G1: "
<<
G1
<<
std
::
endl
;
std
::
cout
<<
"G1Q: "
<<
G1Q
<<
std
::
endl
;
std
::
cout
<<
"G1KV: "
<<
G1KV
<<
std
::
endl
;
std
::
cout
<<
"alpha: "
<<
alpha
<<
std
::
endl
;
std
::
cout
<<
"input_permute: "
<<
input_permute
<<
std
::
endl
;
std
::
cout
<<
"output_permute: "
<<
output_permute
<<
std
::
endl
;
...
...
@@ -345,45 +348,57 @@ int run(int argc, char* argv[])
std
::
cout
<<
"seed: "
<<
seed
<<
std
::
endl
;
std
::
cout
<<
"offset: "
<<
offset
<<
std
::
endl
;
const
ck
::
index_t
BatchCount
=
G0
*
G1
;
const
ck
::
index_t
BatchCount
=
G0
*
G1
Q
;
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
Q
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// Q layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1, M, K]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
K
,
K
,
G1
Q
*
K
,
1
}
// Q layout [G0, M, G1
Q
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1
Q
, M, K]
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G1
KV
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// K layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G1, N, K]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
KV
*
K
,
K
,
G1
KV
*
K
,
1
}
// K layout [G0, N, G1
KV
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
KV
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G1
KV
, N, K]
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G1
KV
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// V layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G1, N, O]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
KV
*
O
,
O
,
1
,
G1
KV
*
O
}
// V layout [G0, N, G1
KV
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
KV
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G1
KV
, N, O]
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
Q
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_strides
=
output_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// Y layout [G0, M, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1, M, O]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
O
,
O
,
G1
Q
*
O
,
1
}
// Y layout [G0, M, G1
Q
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1
Q
, M, O]
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
Q
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1Q
*
N
,
N
,
G1Q
*
N
,
1
}
// Z layout [G0, M, G1Q, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1Q, M, N]
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_lengths
{
G0
,
G1Q
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1Q
*
K
,
K
,
G1Q
*
K
,
1
}
// KGrad layout [G0, N, G1Q, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
N
*
K
,
N
*
K
,
K
,
1
};
// KGrad layout [G0, G1Q, N, K]
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_lengths
{
G0
,
G1Q
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1Q
*
O
,
O
,
1
,
G1Q
*
O
}
// VGrad layout [G0, N, G1Q, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
N
*
O
,
N
*
O
,
1
,
O
};
// VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^
// LSE
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
*
M
,
M
,
1
};
// LSE layout [G0, G1, M]
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
Q
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
Q
*
M
,
M
,
1
};
// LSE layout [G0, G1
Q
, M]
Tensor
<
InputDataType
>
q_gs_ms_ks
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
InputDataType
>
k_gs_ns_ks
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
...
...
@@ -392,6 +407,8 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
y_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
InputDataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks
(
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns
(
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
);
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"k_gs_ns_ks: "
<<
k_gs_ns_ks
.
mDesc
<<
std
::
endl
;
...
...
@@ -399,6 +416,8 @@ int run(int argc, char* argv[])
std
::
cout
<<
"v_gs_os_ns: "
<<
v_gs_os_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"kgrad_gs_ns_ks: "
<<
kgrad_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"vgrad_gs_os_ns: "
<<
vgrad_gs_os_ns
.
mDesc
<<
std
::
endl
;
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
0
});
switch
(
init_method
)
...
...
@@ -432,14 +451,14 @@ int run(int argc, char* argv[])
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1
q
, m, o]
// dO dot O = [0; 1; 2; ...]
break
;
case
6
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1
q
, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
...
...
@@ -452,7 +471,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1q, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
...
...
@@ -474,11 +494,21 @@ int run(int argc, char* argv[])
Tensor
<
LSEDataType
>
lse_g_m
({
BatchCount
,
M
});
q_gs_ms_ks
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_gs_ns_ks
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
v_gs_os_ns
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
[
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_g_n_k
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
(
G1Q
/
G1KV
);
self
(
idx
)
=
k_gs_ns_ks
(
g0
,
g1kv
,
idx
[
1
],
idx
[
2
]);
});
v_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
(
G1Q
/
G1KV
);
self
(
idx
)
=
v_gs_os_ns
(
g0
,
g1kv
,
idx
[
2
],
idx
[
1
]);
});
// qkv gradients have the same descriptor as with qkv
DeviceMem
q_device_buf
(
sizeof
(
InputDataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
...
...
@@ -488,8 +518,8 @@ int run(int argc, char* argv[])
DeviceMem
y_device_buf
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
lse_device_buf
(
sizeof
(
LSEDataType
)
*
lse_gs_ms
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
qgrad_device_buf
(
sizeof
(
OutputDataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
kgrad_device_buf
(
sizeof
(
OutputDataType
)
*
k_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
vgrad_device_buf
(
sizeof
(
OutputDataType
)
*
v_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
kgrad_device_buf
(
sizeof
(
OutputDataType
)
*
k
grad
_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
vgrad_device_buf
(
sizeof
(
OutputDataType
)
*
v
grad
_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
ygrad_device_buf
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
q_device_buf
.
ToDevice
(
q_gs_ms_ks
.
mData
.
data
());
...
...
@@ -513,8 +543,8 @@ int run(int argc, char* argv[])
static_cast
<
OutputDataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
nullptr
,
//
p_acc0_bias;
nullptr
,
//
p_acc1_bias;
nullptr
,
// p_acc0_bias;
nullptr
,
// p_acc1_bias;
nullptr
,
nullptr
,
q_gs_ms_ks_lengths
,
...
...
@@ -528,6 +558,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
,
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
,
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...
...
@@ -560,8 +594,8 @@ int run(int argc, char* argv[])
static_cast
<
OutputDataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
nullptr
,
//
p_acc0_bias;
nullptr
,
//
p_acc1_bias;
nullptr
,
// p_acc0_bias;
nullptr
,
// p_acc1_bias;
nullptr
,
nullptr
,
q_gs_ms_ks_lengths
,
...
...
@@ -575,6 +609,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
,
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
,
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...
...
@@ -614,7 +652,7 @@ int run(int argc, char* argv[])
// copy z matirx data form device
z_device_buf
.
FromDevice
(
z_gs_ms_ns
.
mData
.
data
());
z_gs_ms_ns
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
[
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_n
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
// std::cout << "z_g_m_n ref:\n" << z_g_m_n;
bool
pass
=
true
;
...
...
@@ -634,10 +672,10 @@ int run(int argc, char* argv[])
p_dropout_in_uint8_t
,
rp_dropout
);
y_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
lse_gs_ms
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
]);
});
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_m
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
]);
});
y_device_buf
.
ToDevice
(
y_gs_ms_os
.
mData
.
data
());
lse_device_buf
.
ToDevice
(
lse_gs_ms
.
mData
.
data
());
...
...
@@ -655,7 +693,7 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
ygrad_dot_y_g_m
({
BatchCount
,
M
});
ygrad_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
ygrad_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
ygrad_g_m_o
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
#if PRINT_HOST
...
...
@@ -757,12 +795,16 @@ int run(int argc, char* argv[])
#endif
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_host_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
);
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_device_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
);
qgrad_device_buf
.
FromDevice
(
qgrad_gs_ms_ks_device_result
.
mData
.
data
());
kgrad_device_buf
.
FromDevice
(
kgrad_gs_ns_ks_device_result
.
mData
.
data
());
...
...
@@ -770,26 +812,26 @@ int run(int argc, char* argv[])
// permute
qgrad_gs_ms_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
qgrad_g_m_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
kgrad_gs_ns_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
kgrad_g_n_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
vgrad_gs_os_ns_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
vgrad_g_n_o
(
g
,
idx
[
3
],
idx
[
2
]);
});
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp
View file @
e87ddb0e
...
...
@@ -270,14 +270,15 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
M
=
512
;
ck
::
index_t
N
=
512
;
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
G0
=
4
;
ck
::
index_t
G1
=
6
;
// y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
// y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
ck
::
index_t
M
=
512
;
ck
::
index_t
N
=
512
;
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
G0
=
4
;
ck
::
index_t
G1Q
=
6
;
// h_q
ck
::
index_t
G1KV
=
6
;
// h_kv
bool
input_permute
=
false
;
bool
output_permute
=
false
;
...
...
@@ -296,32 +297,33 @@ int run(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
1
3
)
else
if
(
argc
==
1
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1
=
std
::
stoi
(
argv
[
9
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1Q
=
std
::
stoi
(
argv
[
9
]);
G1KV
=
std
::
stoi
(
argv
[
10
]);
p_drop
=
std
::
stof
(
argv
[
1
0
]);
p_drop
=
std
::
stof
(
argv
[
1
1
]);
input_permute
=
std
::
stoi
(
argv
[
1
1
]);
output_permute
=
std
::
stoi
(
argv
[
1
2
]);
input_permute
=
std
::
stoi
(
argv
[
1
2
]);
output_permute
=
std
::
stoi
(
argv
[
1
3
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 1
1
: M, N, K, O, G0, G1
\n
"
);
printf
(
"arg1
0
:
scale (alpha)
\n
"
);
printf
(
"arg1
1
to 1
2
: input / output permute
\n
"
);
printf
(
"arg4 to 1
0
: M, N, K, O, G0, G1
Q, G1KV
\n
"
);
printf
(
"arg1
1
:
p_drop
\n
"
);
printf
(
"arg1
2
to 1
3
: input / output permute
\n
"
);
exit
(
0
);
}
...
...
@@ -338,7 +340,8 @@ int run(int argc, char* argv[])
std
::
cout
<<
"K: "
<<
K
<<
std
::
endl
;
std
::
cout
<<
"O: "
<<
O
<<
std
::
endl
;
std
::
cout
<<
"G0: "
<<
G0
<<
std
::
endl
;
std
::
cout
<<
"G1: "
<<
G1
<<
std
::
endl
;
std
::
cout
<<
"G1Q: "
<<
G1Q
<<
std
::
endl
;
std
::
cout
<<
"G1KV: "
<<
G1KV
<<
std
::
endl
;
std
::
cout
<<
"alpha: "
<<
alpha
<<
std
::
endl
;
std
::
cout
<<
"input_permute: "
<<
input_permute
<<
std
::
endl
;
std
::
cout
<<
"output_permute: "
<<
output_permute
<<
std
::
endl
;
...
...
@@ -346,45 +349,57 @@ int run(int argc, char* argv[])
std
::
cout
<<
"seed: "
<<
seed
<<
std
::
endl
;
std
::
cout
<<
"offset: "
<<
offset
<<
std
::
endl
;
const
ck
::
index_t
BatchCount
=
G0
*
G1
;
const
ck
::
index_t
BatchCount
=
G0
*
G1
Q
;
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
Q
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// Q layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1, M, K]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
K
,
K
,
G1
Q
*
K
,
1
}
// Q layout [G0, M, G1
Q
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1
Q
, M, K]
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G1
KV
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// K layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G1, N, K]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
KV
*
K
,
K
,
G1
KV
*
K
,
1
}
// K layout [G0, N, G1
KV
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
KV
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G1
KV
, N, K]
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G1
KV
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// V layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G1, N, O]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
KV
*
O
,
O
,
1
,
G1
KV
*
O
}
// V layout [G0, N, G1
KV
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
KV
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G1
KV
, N, O]
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
Q
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_strides
=
output_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// Y layout [G0, M, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1, M, O]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
O
,
O
,
G1
Q
*
O
,
1
}
// Y layout [G0, M, G1
Q
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1
Q
, M, O]
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
Q
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1Q
*
N
,
N
,
G1Q
*
N
,
1
}
// Z layout [G0, M, G1Q, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1Q, M, N]
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_lengths
{
G0
,
G1Q
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1Q
*
K
,
K
,
G1Q
*
K
,
1
}
// KGrad layout [G0, N, G1Q, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
N
*
K
,
N
*
K
,
K
,
1
};
// KGrad layout [G0, G1Q, N, K]
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_lengths
{
G0
,
G1Q
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1Q
*
O
,
O
,
1
,
G1Q
*
O
}
// VGrad layout [G0, N, G1Q, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
N
*
O
,
N
*
O
,
1
,
O
};
// VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^
// LSE
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
*
M
,
M
,
1
};
// LSE layout [G0, G1, M]
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
Q
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
Q
*
M
,
M
,
1
};
// LSE layout [G0, G1
Q
, M]
Tensor
<
InputDataType
>
q_gs_ms_ks
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
InputDataType
>
k_gs_ns_ks
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
...
...
@@ -394,6 +409,8 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
DDataType
>
d_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks
(
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns
(
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
);
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"k_gs_ns_ks: "
<<
k_gs_ns_ks
.
mDesc
<<
std
::
endl
;
...
...
@@ -402,6 +419,8 @@ int run(int argc, char* argv[])
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d_gs_ms_os: "
<<
d_gs_ms
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"kgrad_gs_ns_ks: "
<<
kgrad_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"vgrad_gs_os_ns: "
<<
vgrad_gs_os_ns
.
mDesc
<<
std
::
endl
;
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
0
});
switch
(
init_method
)
...
...
@@ -435,14 +454,14 @@ int run(int argc, char* argv[])
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1
q
, m, o]
// dO dot O = [0; 1; 2; ...]
break
;
case
6
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1
q
, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
...
...
@@ -455,7 +474,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1q, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
...
...
@@ -477,11 +497,21 @@ int run(int argc, char* argv[])
Tensor
<
LSEDataType
>
lse_g_m
({
BatchCount
,
M
});
q_gs_ms_ks
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_gs_ns_ks
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
v_gs_os_ns
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
[
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_g_n_k
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
(
G1Q
/
G1KV
);
self
(
idx
)
=
k_gs_ns_ks
(
g0
,
g1kv
,
idx
[
1
],
idx
[
2
]);
});
v_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
(
G1Q
/
G1KV
);
self
(
idx
)
=
v_gs_os_ns
(
g0
,
g1kv
,
idx
[
2
],
idx
[
1
]);
});
// qkv gradients have the same descriptor as with qkv
DeviceMem
q_device_buf
(
sizeof
(
InputDataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
...
...
@@ -491,8 +521,8 @@ int run(int argc, char* argv[])
DeviceMem
y_device_buf
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
lse_device_buf
(
sizeof
(
LSEDataType
)
*
lse_gs_ms
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
qgrad_device_buf
(
sizeof
(
OutputDataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
kgrad_device_buf
(
sizeof
(
OutputDataType
)
*
k_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
vgrad_device_buf
(
sizeof
(
OutputDataType
)
*
v_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
kgrad_device_buf
(
sizeof
(
OutputDataType
)
*
k
grad
_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
vgrad_device_buf
(
sizeof
(
OutputDataType
)
*
v
grad
_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
ygrad_device_buf
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d_device_buf
(
sizeof
(
DDataType
)
*
d_gs_ms
.
mDesc
.
GetElementSpaceSize
());
...
...
@@ -533,6 +563,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
,
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
,
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...
...
@@ -581,6 +615,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
,
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
,
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...
...
@@ -620,7 +658,7 @@ int run(int argc, char* argv[])
// copy z matirx data form device
z_device_buf
.
FromDevice
(
z_gs_ms_ns
.
mData
.
data
());
z_gs_ms_ns
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
[
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_n
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
// std::cout << "z_g_m_n ref:\n" << z_g_m_n;
bool
pass
=
true
;
...
...
@@ -640,10 +678,10 @@ int run(int argc, char* argv[])
p_dropout_in_uint8_t
,
rp_dropout
);
y_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
lse_gs_ms
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
]);
});
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_m
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
]);
});
y_device_buf
.
ToDevice
(
y_gs_ms_os
.
mData
.
data
());
lse_device_buf
.
ToDevice
(
lse_gs_ms
.
mData
.
data
());
...
...
@@ -661,7 +699,7 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
ygrad_dot_y_g_m
({
BatchCount
,
M
});
ygrad_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
ygrad_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
ygrad_g_m_o
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
#if PRINT_HOST
...
...
@@ -763,12 +801,16 @@ int run(int argc, char* argv[])
#endif
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_host_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
);
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_device_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
);
qgrad_device_buf
.
FromDevice
(
qgrad_gs_ms_ks_device_result
.
mData
.
data
());
kgrad_device_buf
.
FromDevice
(
kgrad_gs_ns_ks_device_result
.
mData
.
data
());
...
...
@@ -776,26 +818,26 @@ int run(int argc, char* argv[])
// permute
qgrad_gs_ms_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
qgrad_g_m_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
kgrad_gs_ns_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
kgrad_g_n_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
vgrad_gs_os_ns_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
vgrad_g_n_o
(
g
,
idx
[
3
],
idx
[
2
]);
});
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_v2.cpp
View file @
e87ddb0e
...
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
bool
Deterministic
=
false
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
#if(DIM <= 32)
using
DeviceGemmInstance
=
...
...
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
MaskingSpec
>
;
// MaskingSpecialization
#elif(DIM <= 64)
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
...
...
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
MaskingSpec
>
;
// MaskingSpecialization
#elif(DIM <= 128)
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
...
...
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
MaskingSpec
>
;
// MaskingSpecialization
#endif
// Ref Gemm0: DataType in, AccDataType out
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp
View file @
e87ddb0e
...
...
@@ -113,11 +113,11 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32)
// clang-format off
using
DeviceGemmInstanceFWD
=
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec|
Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| |
|
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| |
|
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
void
,
void
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
8
,
4
,
MaskingSpec
,
Deterministic
>
;
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
void
,
void
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
8
,
4
,
MaskingSpec
>
;
using
DeviceGemmInstanceBWD
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| Gemm2| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
...
...
@@ -129,11 +129,11 @@ using DeviceGemmInstanceBWD =
#elif(DIM <= 64)
// clang-format off
using
DeviceGemmInstanceFWD
=
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec|
Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| |
|
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| |
|
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
void
,
void
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
4
,
MaskingSpec
,
Deterministic
>
;
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
void
,
void
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
4
,
MaskingSpec
>
;
using
DeviceGemmInstanceBWD
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| Gemm2| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
...
...
@@ -153,11 +153,11 @@ using DeviceGemmInstanceBWD =
#elif(DIM <= 128)
// clang-format off
using
DeviceGemmInstanceFWD
=
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec|
Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| |
|
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| |
|
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
void
,
void
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
4
,
MaskingSpec
,
Deterministic
>
;
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
void
,
void
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
4
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
4
,
MaskingSpec
>
;
using
DeviceGemmInstanceBWD
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| Gemm2| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
...
...
@@ -299,14 +299,15 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
N
=
500
;
// 512
ck
::
index_t
M
=
500
;
// 512
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
G0
=
4
;
// 54
ck
::
index_t
G1
=
6
;
// 16
// y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
// y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
ck
::
index_t
N
=
500
;
// 512
ck
::
index_t
M
=
500
;
// 512
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
G0
=
4
;
ck
::
index_t
G1Q
=
6
;
// h_q
ck
::
index_t
G1KV
=
6
;
// h_kv
bool
input_permute
=
false
;
bool
output_permute
=
false
;
...
...
@@ -325,32 +326,33 @@ int run(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
1
3
)
else
if
(
argc
==
1
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1
=
std
::
stoi
(
argv
[
9
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1Q
=
std
::
stoi
(
argv
[
9
]);
G1KV
=
std
::
stoi
(
argv
[
10
]);
p_drop
=
std
::
stof
(
argv
[
1
0
]);
p_drop
=
std
::
stof
(
argv
[
1
1
]);
input_permute
=
std
::
stoi
(
argv
[
1
1
]);
output_permute
=
std
::
stoi
(
argv
[
1
2
]);
input_permute
=
std
::
stoi
(
argv
[
1
2
]);
output_permute
=
std
::
stoi
(
argv
[
1
3
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 1
1
: M, N, K, O, G0, G1
\n
"
);
printf
(
"arg1
0
:
scale (alpha)
\n
"
);
printf
(
"arg1
1
to 1
2
: input / output permute
\n
"
);
printf
(
"arg4 to 1
0
: M, N, K, O, G0, G1
Q, G1KV
\n
"
);
printf
(
"arg1
1
:
p_drop
\n
"
);
printf
(
"arg1
2
to 1
3
: input / output permute
\n
"
);
exit
(
0
);
}
...
...
@@ -367,7 +369,8 @@ int run(int argc, char* argv[])
std
::
cout
<<
"K: "
<<
K
<<
std
::
endl
;
std
::
cout
<<
"O: "
<<
O
<<
std
::
endl
;
std
::
cout
<<
"G0: "
<<
G0
<<
std
::
endl
;
std
::
cout
<<
"G1: "
<<
G1
<<
std
::
endl
;
std
::
cout
<<
"G1Q: "
<<
G1Q
<<
std
::
endl
;
std
::
cout
<<
"G1KV: "
<<
G1KV
<<
std
::
endl
;
std
::
cout
<<
"alpha: "
<<
alpha
<<
std
::
endl
;
std
::
cout
<<
"input_permute: "
<<
input_permute
<<
std
::
endl
;
std
::
cout
<<
"output_permute: "
<<
output_permute
<<
std
::
endl
;
...
...
@@ -375,45 +378,57 @@ int run(int argc, char* argv[])
std
::
cout
<<
"seed: "
<<
seed
<<
std
::
endl
;
std
::
cout
<<
"offset: "
<<
offset
<<
std
::
endl
;
const
ck
::
index_t
BatchCount
=
G0
*
G1
;
const
ck
::
index_t
BatchCount
=
G0
*
G1
Q
;
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
Q
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// Q layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1, M, K]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
K
,
K
,
G1
Q
*
K
,
1
}
// Q layout [G0, M, G1
Q
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1
Q
, M, K]
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G1
KV
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// K layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G1, N, K]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
KV
*
K
,
K
,
G1
KV
*
K
,
1
}
// K layout [G0, N, G1
KV
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
KV
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G1
KV
, N, K]
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G1
KV
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// V layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G1, N, O]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
KV
*
O
,
O
,
1
,
G1
KV
*
O
}
// V layout [G0, N, G1
KV
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
KV
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G1
KV
, N, O]
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
Q
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_strides
=
output_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// Y layout [G0, M, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1, M, O]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
O
,
O
,
G1
Q
*
O
,
1
}
// Y layout [G0, M, G1
Q
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1
Q
, M, O]
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
Q
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1Q
*
N
,
N
,
G1Q
*
N
,
1
}
// Z layout [G0, M, G1Q, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1Q, M, N]
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_lengths
{
G0
,
G1Q
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1Q
*
K
,
K
,
G1Q
*
K
,
1
}
// KGrad layout [G0, N, G1Q, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
N
*
K
,
N
*
K
,
K
,
1
};
// KGrad layout [G0, G1Q, N, K]
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_lengths
{
G0
,
G1Q
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1Q
*
O
,
O
,
1
,
G1Q
*
O
}
// VGrad layout [G0, N, G1Q, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
N
*
O
,
N
*
O
,
1
,
O
};
// VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^
// LSE
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
*
M
,
M
,
1
};
// LSE layout [G0, G1, M]
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
Q
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
Q
*
M
,
M
,
1
};
// LSE layout [G0, G1
Q
, M]
Tensor
<
InputDataType
>
q_gs_ms_ks
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
InputDataType
>
k_gs_ns_ks
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
...
...
@@ -424,8 +439,10 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
y_gs_ms_os_device_result
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms_device_result
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_device_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
);
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"k_gs_ns_ks: "
<<
k_gs_ns_ks
.
mDesc
<<
std
::
endl
;
...
...
@@ -467,14 +484,14 @@ int run(int argc, char* argv[])
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1
q
, m, o]
// dO dot O = [0; 1; 2; ...]
break
;
case
6
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1
q
, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
...
...
@@ -487,7 +504,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1q, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
...
...
@@ -612,6 +630,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
,
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
,
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...
...
@@ -656,8 +678,10 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
y_gs_ms_os_host_result
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms_host_result
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_host_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
);
Tensor
<
InputDataType
>
q_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
InputDataType
>
k_g_n_k
({
BatchCount
,
N
,
K
});
...
...
@@ -760,6 +784,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
,
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
,
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...
...
@@ -793,16 +821,24 @@ int run(int argc, char* argv[])
}
q_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
q_g_m_k
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_gs_ns_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
k_g_n_k
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
(
G1Q
/
G1KV
);
self
(
idx
)
=
k_gs_ns_ks
(
g0
,
g1kv
,
idx
[
1
],
idx
[
2
]);
});
v_gs_os_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
v_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
(
G1Q
/
G1KV
);
self
(
idx
)
=
v_gs_os_ns
(
g0
,
g1kv
,
idx
[
2
],
idx
[
1
]);
});
z_fwd_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
z_fwd_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
z_fwd_g_m_n
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
run_attention_fwd_host
(
q_g_m_k
,
...
...
@@ -819,10 +855,10 @@ int run(int argc, char* argv[])
rp_dropout
);
ygrad_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
ygrad_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
ygrad_g_m_o
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
z_bwd_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
z_bwd_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
z_bwd_g_m_n
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
#if PRINT_HOST
...
...
@@ -925,42 +961,42 @@ int run(int argc, char* argv[])
// permute
y_gs_ms_os_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
y_g_m_o
(
g
,
idx
[
2
],
idx
[
3
]);
});
lse_gs_ms_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
lse_g_m
(
g
,
idx
[
2
]);
});
qgrad_gs_ms_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
qgrad_g_m_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
kgrad_gs_ns_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
kgrad_g_n_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
vgrad_gs_os_ns_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
vgrad_g_n_o
(
g
,
idx
[
3
],
idx
[
2
]);
});
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp
View file @
e87ddb0e
...
...
@@ -268,10 +268,11 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
// y_g0_g1
q
_m_o = reshape(y_g_m_o, [G0, G1
Q
, M, O])
// y_g0_m_g1
q
_o = permute(y_g0_g1
q
_m_o, [0, 2, 1, 3])
float
alpha
=
1.
f
/
std
::
sqrt
(
DIM
);
float
p_drop
=
0.0
;
int
h_ratio
=
1
;
// G1Q / G1KV
bool
input_permute
=
true
;
bool
output_permute
=
true
;
...
...
@@ -289,25 +290,26 @@ int run(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
7
)
else
if
(
argc
==
8
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
p_drop
=
std
::
stof
(
argv
[
4
]);
p_drop
=
std
::
stof
(
argv
[
4
]);
h_ratio
=
std
::
stof
(
argv
[
5
]);
input_permute
=
std
::
stoi
(
argv
[
5
]);
output_permute
=
std
::
stoi
(
argv
[
6
]);
input_permute
=
std
::
stoi
(
argv
[
6
]);
output_permute
=
std
::
stoi
(
argv
[
7
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4
to 11: M, N, K, O, G0, G1
\n
"
);
printf
(
"arg
10: scale (alpha)
\n
"
);
printf
(
"arg
11
to
12
: input / output permute
\n
"
);
printf
(
"arg4
: p_drop
\n
"
);
printf
(
"arg
5: h_ratio
\n
"
);
printf
(
"arg
6
to
7
: input / output permute
\n
"
);
exit
(
0
);
}
...
...
@@ -367,49 +369,65 @@ int run(int argc, char* argv[])
std
::
size_t
flop
=
0
,
num_byte
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
M
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
N
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
K
=
DIM
;
int
O
=
DIM
;
int
G0
=
rand
()
%
4
+
1
;
int
G1
=
rand
()
%
4
+
1
;
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
int
M
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
N
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
K
=
DIM
;
int
O
=
DIM
;
int
G0
=
rand
()
%
4
+
1
;
int
G1KV
=
rand
()
%
4
+
1
;
int
G1Q
=
G1KV
*
h_ratio
;
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1Q
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// Q layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1, M, K]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
K
,
K
,
G1
Q
*
K
,
1
}
// Q layout [G0, M, G1
Q
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1
Q
, M, K]
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G1
KV
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// K layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G1, N, K]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1KV
*
K
,
K
,
G1KV
*
K
,
1
}
// K layout [G0, N, G1KV, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1KV
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G1KV, N, K]
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G1
KV
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// V layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G1, N, O]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1KV
*
O
,
O
,
1
,
G1KV
*
O
}
// V layout [G0, N, G1KV, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1KV
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G1KV, N, O]
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
Q
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_strides
=
output_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// Y layout [G0, M, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1, M, O]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
O
,
O
,
G1
Q
*
O
,
1
}
// Y layout [G0, M, G1
Q
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1
Q
, M, O]
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
Q
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1Q
*
N
,
N
,
G1Q
*
N
,
1
}
// Z layout [G0, M, G1Q, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1Q, M, N]
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_lengths
{
G0
,
G1Q
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1Q
*
K
,
K
,
G1Q
*
K
,
1
}
// KGrad layout [G0, N, G1Q, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
N
*
K
,
N
*
K
,
K
,
1
};
// KGrad layout [G0, G1Q, N, K]
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_lengths
{
G0
,
G1Q
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1Q
*
O
,
O
,
1
,
G1Q
*
O
}
// VGrad layout [G0, N, G1Q, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
N
*
O
,
N
*
O
,
1
,
O
};
// VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
// pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^
// LSE
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
*
M
,
M
,
1
};
// LSE layout [G0, G1, M]
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
Q
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
Q
*
M
,
M
,
1
};
// LSE layout [G0, G1
Q
, M]
problem_descs
.
push_back
({
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
,
...
...
@@ -423,13 +441,17 @@ int run(int argc, char* argv[])
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
lse_gs_ms_strides
,
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
,
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
,
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
});
int
BatchCount
=
G0
*
G1
;
int
BatchCount
=
G0
*
G1
Q
;
flop
+=
(
size_t
(
3
)
*
M
*
N
*
K
+
size_t
(
2
)
*
M
*
N
*
O
)
*
2
*
BatchCount
;
// Q/K/V/Y, dQ/dK/dV/dY, LSE
num_byte
+=
(
sizeof
(
InputDataType
)
*
M
*
K
+
sizeof
(
InputDataType
)
*
K
*
N
+
...
...
@@ -446,6 +468,8 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
y_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
InputDataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks
(
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns
(
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
);
if
(
i
<
4
)
{
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
...
...
@@ -454,6 +478,8 @@ int run(int argc, char* argv[])
std
::
cout
<<
"v_gs_os_ns: "
<<
v_gs_os_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"kgrad_gs_ns_ks: "
<<
kgrad_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"vgrad_gs_os_ns: "
<<
vgrad_gs_os_ns
.
mDesc
<<
std
::
endl
;
}
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
0
});
switch
(
init_method
)
...
...
@@ -487,14 +513,16 @@ int run(int argc, char* argv[])
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1q, m, o]
// dO dot O = [0; 1; 2; ...]
break
;
case
6
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1q, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
...
...
@@ -508,7 +536,7 @@ int run(int argc, char* argv[])
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1, m, o]
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1
q
, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
...
...
@@ -529,13 +557,21 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
p_drop_g_m_n
({
BatchCount
,
M
,
N
});
q_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
q_g_m_k
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_gs_ns_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
k_g_n_k
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
h_ratio
;
self
(
idx
)
=
k_gs_ns_ks
(
g0
,
g1kv
,
idx
[
1
],
idx
[
2
]);
});
v_gs_os_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
v_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
h_ratio
;
self
(
idx
)
=
v_gs_os_ns
(
g0
,
g1kv
,
idx
[
2
],
idx
[
1
]);
});
q_g_m_ks
.
push_back
(
q_g_m_k
);
...
...
@@ -554,6 +590,8 @@ int run(int argc, char* argv[])
z_tensors
.
push_back
(
z_gs_ms_ns
);
lse_tensors
.
push_back
(
lse_gs_ms
);
ygrad_tensors
.
push_back
(
ygrad_gs_ms_os
);
kgrad_tensors
.
push_back
(
kgrad_gs_ns_ks
);
vgrad_tensors
.
push_back
(
vgrad_gs_os_ns
);
q_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
InputDataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
k_tensors_device
.
emplace_back
(
...
...
@@ -568,10 +606,10 @@ int run(int argc, char* argv[])
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
LSEDataType
)
*
lse_gs_ms
.
GetElementSpaceSize
()));
qgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
kgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
k_gs_ns_ks
.
GetElementSpaceSize
()));
vgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
v_gs_os_ns
.
GetElementSpaceSize
()));
kgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
k
grad
_gs_ns_ks
.
GetElementSpaceSize
()));
vgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
v
grad
_gs_os_ns
.
GetElementSpaceSize
()));
ygrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
GetElementSpaceSize
()));
q_tensors_device
.
back
()
->
ToDevice
(
q_gs_ms_ks
.
data
());
...
...
@@ -674,11 +712,11 @@ int run(int argc, char* argv[])
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
G1
=
v
_tensors
[
i
].
GetLengths
()[
1
];
int
G1
Q
=
q
_tensors
[
i
].
GetLengths
()[
1
];
// copy z matirx data form device
z_tensors_device
[
i
]
->
FromDevice
(
z_tensors
[
i
].
mData
.
data
());
z_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_ns
[
i
](
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
z_g_m_ns
[
i
](
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
run_attention_fwd_host
(
q_g_m_ks
[
i
],
k_g_n_ks
[
i
],
...
...
@@ -694,11 +732,11 @@ int run(int argc, char* argv[])
rp_dropout
);
y_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_os
[
i
](
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
self
(
idx
)
=
y_g_m_os
[
i
](
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
y_tensors_device
[
i
]
->
ToDevice
(
y_tensors
[
i
].
data
());
lse_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_ms
[
i
](
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
]);
self
(
idx
)
=
lse_g_ms
[
i
](
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
]);
});
lse_tensors_device
[
i
]
->
ToDevice
(
lse_tensors
[
i
].
data
());
qgrad_tensors_device
[
i
]
->
SetZero
();
...
...
@@ -711,13 +749,13 @@ int run(int argc, char* argv[])
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
G0
=
v
_tensors
[
i
].
GetLengths
()[
0
];
int
G1
=
v
_tensors
[
i
].
GetLengths
()[
1
];
int
G0
=
q
_tensors
[
i
].
GetLengths
()[
0
];
int
G1
Q
=
q
_tensors
[
i
].
GetLengths
()[
1
];
int
O
=
v_tensors
[
i
].
GetLengths
()[
2
];
int
N
=
v_tensors
[
i
].
GetLengths
()[
3
];
int
M
=
q_tensors
[
i
].
GetLengths
()[
2
];
int
K
=
q_tensors
[
i
].
GetLengths
()[
3
];
int
BatchCount
=
G0
*
G1
;
int
BatchCount
=
G0
*
G1
Q
;
Tensor
<
OutputDataType
>
qgrad_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
OutputDataType
>
kgrad_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
OutputDataType
>
vgrad_g_n_o
({
BatchCount
,
N
,
O
});
...
...
@@ -727,7 +765,7 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
ygrad_g_m_o
({
BatchCount
,
M
,
O
});
ygrad_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
ygrad_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
ygrad_g_m_o
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
auto
ref_gemm0_grad
=
ReferenceGemm0GradInstance
{};
auto
ref_gemm0_grad_invoker
=
ref_gemm0_grad
.
MakeInvoker
();
...
...
@@ -770,43 +808,43 @@ int run(int argc, char* argv[])
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_host_result
(
q_tensors
[
i
].
GetLengths
(),
q_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
k_tensors
[
i
].
GetLengths
(),
k_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
v_tensors
[
i
].
GetLengths
(),
v_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
k
grad
_tensors
[
i
].
GetLengths
(),
k
grad
_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
v
grad
_tensors
[
i
].
GetLengths
(),
v
grad
_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_device_result
(
q_tensors
[
i
].
GetLengths
(),
q_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
k_tensors
[
i
].
GetLengths
(),
k_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
v_tensors
[
i
].
GetLengths
(),
v_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
k
grad
_tensors
[
i
].
GetLengths
(),
k
grad
_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
v
grad
_tensors
[
i
].
GetLengths
(),
v
grad
_tensors
[
i
].
GetStrides
());
qgrad_tensors_device
[
i
]
->
FromDevice
(
qgrad_gs_ms_ks_device_result
.
data
());
kgrad_tensors_device
[
i
]
->
FromDevice
(
kgrad_gs_ns_ks_device_result
.
data
());
vgrad_tensors_device
[
i
]
->
FromDevice
(
vgrad_gs_os_ns_device_result
.
data
());
// permute
qgrad_gs_ms_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
qgrad_g_m_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
kgrad_gs_ns_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
kgrad_g_n_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
vgrad_gs_os_ns_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
vgrad_g_n_o
(
g
,
idx
[
3
],
idx
[
2
]);
});
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp
View file @
e87ddb0e
...
...
@@ -24,7 +24,7 @@ Kernel outputs:
*/
#define USING_MASK 0
#define DIM
32
// DIM should be a multiple of 8.
#define DIM
128
// DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
...
...
@@ -269,10 +269,11 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
// y_g0_g1
q
_m_o = reshape(y_g_m_o, [G0, G1
Q
, M, O])
// y_g0_m_g1
q
_o = permute(y_g0_g1
q
_m_o, [0, 2, 1, 3])
float
alpha
=
1.
f
/
std
::
sqrt
(
DIM
);
float
p_drop
=
0.0
;
int
h_ratio
=
1
;
// G1Q / G1KV
bool
input_permute
=
true
;
bool
output_permute
=
true
;
...
...
@@ -290,25 +291,26 @@ int run(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
7
)
else
if
(
argc
==
8
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
p_drop
=
std
::
stof
(
argv
[
4
]);
p_drop
=
std
::
stof
(
argv
[
4
]);
h_ratio
=
std
::
stof
(
argv
[
5
]);
input_permute
=
std
::
stoi
(
argv
[
5
]);
output_permute
=
std
::
stoi
(
argv
[
6
]);
input_permute
=
std
::
stoi
(
argv
[
6
]);
output_permute
=
std
::
stoi
(
argv
[
7
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4
to 11: M, N, K, O, G0, G1
\n
"
);
printf
(
"arg
10: scale (alpha)
\n
"
);
printf
(
"arg
11
to
12
: input / output permute
\n
"
);
printf
(
"arg4
: p_drop
\n
"
);
printf
(
"arg
5: h_ratio
\n
"
);
printf
(
"arg
6
to
7
: input / output permute
\n
"
);
exit
(
0
);
}
...
...
@@ -371,49 +373,65 @@ int run(int argc, char* argv[])
std
::
size_t
flop
=
0
,
num_byte
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
M
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
N
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
K
=
DIM
;
int
O
=
DIM
;
int
G0
=
rand
()
%
4
+
1
;
int
G1
=
rand
()
%
4
+
1
;
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
int
M
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
N
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
K
=
DIM
;
int
O
=
DIM
;
int
G0
=
rand
()
%
4
+
1
;
int
G1KV
=
rand
()
%
4
+
1
;
int
G1Q
=
G1KV
*
h_ratio
;
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1Q
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// Q layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1, M, K]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
K
,
K
,
G1
Q
*
K
,
1
}
// Q layout [G0, M, G1
Q
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1
Q
, M, K]
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G1
KV
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// K layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G1, N, K]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1KV
*
K
,
K
,
G1KV
*
K
,
1
}
// K layout [G0, N, G1KV, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1KV
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G1KV, N, K]
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G1
KV
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// V layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G1, N, O]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1KV
*
O
,
O
,
1
,
G1KV
*
O
}
// V layout [G0, N, G1KV, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1KV
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G1KV, N, O]
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
Q
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_strides
=
output_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// Y layout [G0, M, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1, M, O]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
O
,
O
,
G1
Q
*
O
,
1
}
// Y layout [G0, M, G1
Q
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1
Q
, M, O]
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
Q
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1Q
*
N
,
N
,
G1Q
*
N
,
1
}
// Z layout [G0, M, G1Q, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1Q, M, N]
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_lengths
{
G0
,
G1Q
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1Q
*
K
,
K
,
G1Q
*
K
,
1
}
// KGrad layout [G0, N, G1Q, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
N
*
K
,
N
*
K
,
K
,
1
};
// KGrad layout [G0, G1Q, N, K]
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_lengths
{
G0
,
G1Q
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1Q
*
O
,
O
,
1
,
G1Q
*
O
}
// VGrad layout [G0, N, G1Q, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
N
*
O
,
N
*
O
,
1
,
O
};
// VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
// pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^
// LSE
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
*
M
,
M
,
1
};
// LSE layout [G0, G1, M]
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
Q
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
Q
*
M
,
M
,
1
};
// LSE layout [G0, G1
Q
, M]
problem_descs
.
push_back
({
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
,
...
...
@@ -427,13 +445,17 @@ int run(int argc, char* argv[])
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
lse_gs_ms_strides
,
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
,
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
,
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
});
int
BatchCount
=
G0
*
G1
;
int
BatchCount
=
G0
*
G1
Q
;
flop
+=
(
size_t
(
3
)
*
M
*
N
*
K
+
size_t
(
2
)
*
M
*
N
*
O
)
*
2
*
BatchCount
;
// Q/K/V/Y, dQ/dK/dV/dY, LSE
num_byte
+=
(
sizeof
(
InputDataType
)
*
M
*
K
+
sizeof
(
InputDataType
)
*
K
*
N
+
...
...
@@ -451,6 +473,8 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
DDataType
>
d_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks
(
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns
(
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
);
if
(
i
<
4
)
{
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
...
...
@@ -460,6 +484,8 @@ int run(int argc, char* argv[])
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d_gs_ms_os: "
<<
d_gs_ms
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"kgrad_gs_ns_ks: "
<<
kgrad_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"vgrad_gs_os_ns: "
<<
vgrad_gs_os_ns
.
mDesc
<<
std
::
endl
;
}
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
0
});
switch
(
init_method
)
...
...
@@ -493,14 +519,16 @@ int run(int argc, char* argv[])
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1q, m, o]
// dO dot O = [0; 1; 2; ...]
break
;
case
6
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1q, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
...
...
@@ -514,7 +542,7 @@ int run(int argc, char* argv[])
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1, m, o]
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1
q
, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
...
...
@@ -536,13 +564,21 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
p_drop_g_m_n
({
BatchCount
,
M
,
N
});
q_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
q_g_m_k
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_gs_ns_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
k_g_n_k
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
h_ratio
;
self
(
idx
)
=
k_gs_ns_ks
(
g0
,
g1kv
,
idx
[
1
],
idx
[
2
]);
});
v_gs_os_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
v_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
h_ratio
;
self
(
idx
)
=
v_gs_os_ns
(
g0
,
g1kv
,
idx
[
2
],
idx
[
1
]);
});
q_g_m_ks
.
push_back
(
q_g_m_k
);
...
...
@@ -562,6 +598,8 @@ int run(int argc, char* argv[])
z_tensors
.
push_back
(
z_gs_ms_ns
);
lse_tensors
.
push_back
(
lse_gs_ms
);
ygrad_tensors
.
push_back
(
ygrad_gs_ms_os
);
kgrad_tensors
.
push_back
(
kgrad_gs_ns_ks
);
vgrad_tensors
.
push_back
(
vgrad_gs_os_ns
);
q_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
InputDataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
k_tensors_device
.
emplace_back
(
...
...
@@ -578,10 +616,10 @@ int run(int argc, char* argv[])
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DDataType
)
*
d_gs_ms
.
GetElementSpaceSize
()));
qgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
kgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
k_gs_ns_ks
.
GetElementSpaceSize
()));
vgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
v_gs_os_ns
.
GetElementSpaceSize
()));
kgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
k
grad
_gs_ns_ks
.
GetElementSpaceSize
()));
vgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
v
grad
_gs_os_ns
.
GetElementSpaceSize
()));
ygrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
GetElementSpaceSize
()));
q_tensors_device
.
back
()
->
ToDevice
(
q_gs_ms_ks
.
data
());
...
...
@@ -687,11 +725,11 @@ int run(int argc, char* argv[])
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
G1
=
v
_tensors
[
i
].
GetLengths
()[
1
];
int
G1
Q
=
q
_tensors
[
i
].
GetLengths
()[
1
];
// copy z matirx data form device
z_tensors_device
[
i
]
->
FromDevice
(
z_tensors
[
i
].
mData
.
data
());
z_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_ns
[
i
](
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
z_g_m_ns
[
i
](
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
run_attention_fwd_host
(
q_g_m_ks
[
i
],
k_g_n_ks
[
i
],
...
...
@@ -707,11 +745,11 @@ int run(int argc, char* argv[])
rp_dropout
);
y_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_os
[
i
](
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
self
(
idx
)
=
y_g_m_os
[
i
](
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
y_tensors_device
[
i
]
->
ToDevice
(
y_tensors
[
i
].
data
());
lse_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_ms
[
i
](
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
]);
self
(
idx
)
=
lse_g_ms
[
i
](
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
]);
});
lse_tensors_device
[
i
]
->
ToDevice
(
lse_tensors
[
i
].
data
());
qgrad_tensors_device
[
i
]
->
SetZero
();
...
...
@@ -724,13 +762,13 @@ int run(int argc, char* argv[])
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
G0
=
v
_tensors
[
i
].
GetLengths
()[
0
];
int
G1
=
v
_tensors
[
i
].
GetLengths
()[
1
];
int
G0
=
q
_tensors
[
i
].
GetLengths
()[
0
];
int
G1
Q
=
q
_tensors
[
i
].
GetLengths
()[
1
];
int
O
=
v_tensors
[
i
].
GetLengths
()[
2
];
int
N
=
v_tensors
[
i
].
GetLengths
()[
3
];
int
M
=
q_tensors
[
i
].
GetLengths
()[
2
];
int
K
=
q_tensors
[
i
].
GetLengths
()[
3
];
int
BatchCount
=
G0
*
G1
;
int
BatchCount
=
G0
*
G1
Q
;
Tensor
<
OutputDataType
>
qgrad_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
OutputDataType
>
kgrad_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
OutputDataType
>
vgrad_g_n_o
({
BatchCount
,
N
,
O
});
...
...
@@ -740,7 +778,7 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
ygrad_g_m_o
({
BatchCount
,
M
,
O
});
ygrad_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
ygrad_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
ygrad_g_m_o
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
auto
ref_gemm0_grad
=
ReferenceGemm0GradInstance
{};
auto
ref_gemm0_grad_invoker
=
ref_gemm0_grad
.
MakeInvoker
();
...
...
@@ -783,43 +821,43 @@ int run(int argc, char* argv[])
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_host_result
(
q_tensors
[
i
].
GetLengths
(),
q_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
k_tensors
[
i
].
GetLengths
(),
k_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
v_tensors
[
i
].
GetLengths
(),
v_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
k
grad
_tensors
[
i
].
GetLengths
(),
k
grad
_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
v
grad
_tensors
[
i
].
GetLengths
(),
v
grad
_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_device_result
(
q_tensors
[
i
].
GetLengths
(),
q_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
k_tensors
[
i
].
GetLengths
(),
k_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
v_tensors
[
i
].
GetLengths
(),
v_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
k
grad
_tensors
[
i
].
GetLengths
(),
k
grad
_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
v
grad
_tensors
[
i
].
GetLengths
(),
v
grad
_tensors
[
i
].
GetStrides
());
qgrad_tensors_device
[
i
]
->
FromDevice
(
qgrad_gs_ms_ks_device_result
.
data
());
kgrad_tensors_device
[
i
]
->
FromDevice
(
kgrad_gs_ns_ks_device_result
.
data
());
vgrad_tensors_device
[
i
]
->
FromDevice
(
vgrad_gs_os_ns_device_result
.
data
());
// permute
qgrad_gs_ms_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
qgrad_g_m_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
kgrad_gs_ns_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
kgrad_g_n_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
vgrad_gs_os_ns_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
vgrad_g_n_o
(
g
,
idx
[
3
],
idx
[
2
]);
});
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v2.cpp
View file @
e87ddb0e
...
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
bool
Deterministic
=
true
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
#if(DIM <= 32)
using
DeviceGemmInstance
=
...
...
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
1
,
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
MaskingSpec
>
;
// MaskingSpecialization
#elif(DIM <= 64)
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
...
...
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
1
,
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
MaskingSpec
>
;
// MaskingSpecialization
#elif(DIM <= 128)
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
...
...
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
1
,
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
MaskingSpec
>
;
// MaskingSpecialization
#endif
// Ref Gemm0: DataType in, AccDataType out
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
View file @
e87ddb0e
...
...
@@ -112,11 +112,11 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32)
// clang-format off
using
DeviceGemmInstanceFWD
=
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec|
Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| |
|
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| |
|
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
void
,
void
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
8
,
1
,
MaskingSpec
,
Deterministic
>
;
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
void
,
void
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
32
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
8
,
1
,
MaskingSpec
>
;
using
DeviceGemmInstanceBWD
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| Gemm2| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
...
...
@@ -128,11 +128,11 @@ using DeviceGemmInstanceBWD =
#elif(DIM <= 64)
// clang-format off
using
DeviceGemmInstanceFWD
=
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec|
Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| |
|
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| |
|
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
void
,
void
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
1
,
MaskingSpec
,
Deterministic
>
;
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
void
,
void
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
1
,
MaskingSpec
>
;
using
DeviceGemmInstanceBWD
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| Gemm2| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
...
...
@@ -152,11 +152,11 @@ using DeviceGemmInstanceBWD =
#elif(DIM <= 128)
// clang-format off
using
DeviceGemmInstanceFWD
=
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec|
Deterministic|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| |
|
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| |
|
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
void
,
void
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
1
,
MaskingSpec
,
Deterministic
>
;
// #################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| ADataType| BDataType| B1DataType| CDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|Dropout| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector| D1BlockTransfer| MaskingSpec|
// #################################################################################| | | | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Step| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| _NPerBlock| SrcScalar| |
// #################################################################################| | | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | PerVector| |
// #################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
InputDataType
,
InputDataType
,
InputDataType
,
InputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
void
,
void
,
AccDataType
,
ShuffleDataType
,
QKVElementOp
,
QKVElementOp
,
Scale
,
QKVElementOp
,
YElementOp
,
GemmSpec
,
TensorSpecQ
,
TensorSpecK
,
TensorSpecV
,
TensorSpecY
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
1
,
MaskingSpec
>
;
using
DeviceGemmInstanceBWD
=
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| Gemm2| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| D0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
...
...
@@ -298,10 +298,11 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
// y_g0_g1
q
_m_o = reshape(y_g_m_o, [G0, G1
Q
, M, O])
// y_g0_m_g1
q
_o = permute(y_g0_g1
q
_m_o, [0, 2, 1, 3])
float
alpha
=
1.
f
/
std
::
sqrt
(
DIM
);
float
p_drop
=
0.2
;
int
h_ratio
=
1
;
// G1Q / G1KV
bool
input_permute
=
true
;
bool
output_permute
=
true
;
...
...
@@ -319,25 +320,26 @@ int run(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
7
)
else
if
(
argc
==
8
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
p_drop
=
std
::
stof
(
argv
[
4
]);
p_drop
=
std
::
stof
(
argv
[
4
]);
h_ratio
=
std
::
stof
(
argv
[
5
]);
input_permute
=
std
::
stoi
(
argv
[
5
]);
output_permute
=
std
::
stoi
(
argv
[
6
]);
input_permute
=
std
::
stoi
(
argv
[
6
]);
output_permute
=
std
::
stoi
(
argv
[
7
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4
to 11: M, N, K, O, G0, G1
\n
"
);
printf
(
"arg
10: scale (alpha)
\n
"
);
printf
(
"arg
11
to
12
: input / output permute
\n
"
);
printf
(
"arg4
: p_drop
\n
"
);
printf
(
"arg
5: h_ratio
\n
"
);
printf
(
"arg
6
to
7
: input / output permute
\n
"
);
exit
(
0
);
}
...
...
@@ -407,49 +409,65 @@ int run(int argc, char* argv[])
std
::
size_t
flop_bwd
=
0
,
num_byte_bwd
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
M
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
N
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
K
=
DIM
;
int
O
=
DIM
;
int
G0
=
rand
()
%
4
+
1
;
int
G1
=
rand
()
%
4
+
1
;
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
int
M
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
N
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
K
=
DIM
;
int
O
=
DIM
;
int
G0
=
rand
()
%
4
+
1
;
int
G1KV
=
rand
()
%
4
+
1
;
int
G1Q
=
G1KV
*
h_ratio
;
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1Q
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// Q layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1, M, K]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
K
,
K
,
G1
Q
*
K
,
1
}
// Q layout [G0, M, G1
Q
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1
Q
, M, K]
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G1
KV
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// K layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G1, N, K]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1KV
*
K
,
K
,
G1KV
*
K
,
1
}
// K layout [G0, N, G1KV, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1KV
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G1KV, N, K]
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G1
KV
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// V layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G1, N, O]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1KV
*
O
,
O
,
1
,
G1KV
*
O
}
// V layout [G0, N, G1KV, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1KV
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G1KV, N, O]
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
Q
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_strides
=
output_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// Y layout [G0, M, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1, M, O]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
O
,
O
,
G1
Q
*
O
,
1
}
// Y layout [G0, M, G1
Q
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1
Q
, M, O]
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
Q
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1Q
*
N
,
N
,
G1Q
*
N
,
1
}
// Z layout [G0, M, G1Q, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1Q, M, N]
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_lengths
{
G0
,
G1Q
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1Q
*
K
,
K
,
G1Q
*
K
,
1
}
// KGrad layout [G0, N, G1Q, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
N
*
K
,
N
*
K
,
K
,
1
};
// KGrad layout [G0, G1Q, N, K]
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_lengths
{
G0
,
G1Q
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1Q
*
O
,
O
,
1
,
G1Q
*
O
}
// VGrad layout [G0, N, G1Q, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
N
*
O
,
N
*
O
,
1
,
O
};
// VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
// pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^
// LSE
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
*
M
,
M
,
1
};
// LSE layout [G0, G1, M]
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
Q
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
Q
*
M
,
M
,
1
};
// LSE layout [G0, G1
Q
, M]
problem_descs_fwd
.
push_back
({
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
,
...
...
@@ -481,13 +499,17 @@ int run(int argc, char* argv[])
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
lse_gs_ms_strides
,
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
,
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
,
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
});
int
BatchCount
=
G0
*
G1
;
int
BatchCount
=
G0
*
G1
Q
;
flop_fwd
+=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
)
*
BatchCount
;
num_byte_fwd
+=
(
sizeof
(
InputDataType
)
*
M
*
K
+
sizeof
(
InputDataType
)
*
K
*
N
+
sizeof
(
InputDataType
)
*
N
*
O
+
sizeof
(
InputDataType
)
*
M
*
O
)
*
...
...
@@ -510,6 +532,8 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
y_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
InputDataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks
(
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns
(
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
);
if
(
i
<
4
)
{
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
...
...
@@ -518,6 +542,8 @@ int run(int argc, char* argv[])
std
::
cout
<<
"v_gs_os_ns: "
<<
v_gs_os_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"kgrad_gs_ns_ks: "
<<
kgrad_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"vgrad_gs_os_ns: "
<<
vgrad_gs_os_ns
.
mDesc
<<
std
::
endl
;
}
z_fwd_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
0
});
z_bwd_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
0
});
...
...
@@ -552,14 +578,16 @@ int run(int argc, char* argv[])
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1q, m, o]
// dO dot O = [0; 1; 2; ...]
break
;
case
6
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1q, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
...
...
@@ -573,7 +601,7 @@ int run(int argc, char* argv[])
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1, m, o]
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1
q
, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
...
...
@@ -596,13 +624,21 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
p_drop_g_m_n
({
BatchCount
,
M
,
N
});
q_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
q_g_m_k
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_gs_ns_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
k_g_n_k
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
h_ratio
;
self
(
idx
)
=
k_gs_ns_ks
(
g0
,
g1kv
,
idx
[
1
],
idx
[
2
]);
});
v_gs_os_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
v_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
h_ratio
;
self
(
idx
)
=
v_gs_os_ns
(
g0
,
g1kv
,
idx
[
2
],
idx
[
1
]);
});
q_g_m_ks
.
push_back
(
q_g_m_k
);
...
...
@@ -624,6 +660,8 @@ int run(int argc, char* argv[])
z_bwd_tensors
.
push_back
(
z_bwd_gs_ms_ns
);
lse_tensors
.
push_back
(
lse_gs_ms
);
ygrad_tensors
.
push_back
(
ygrad_gs_ms_os
);
kgrad_tensors
.
push_back
(
kgrad_gs_ns_ks
);
vgrad_tensors
.
push_back
(
vgrad_gs_os_ns
);
q_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
InputDataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
...
...
@@ -641,10 +679,10 @@ int run(int argc, char* argv[])
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
LSEDataType
)
*
lse_gs_ms
.
GetElementSpaceSize
()));
qgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
kgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
k_gs_ns_ks
.
GetElementSpaceSize
()));
vgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
v_gs_os_ns
.
GetElementSpaceSize
()));
kgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
k
grad
_gs_ns_ks
.
GetElementSpaceSize
()));
vgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
v
grad
_gs_os_ns
.
GetElementSpaceSize
()));
ygrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
GetElementSpaceSize
()));
...
...
@@ -840,15 +878,15 @@ int run(int argc, char* argv[])
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
G1
=
v
_tensors
[
i
].
GetLengths
()[
1
];
int
G1
Q
=
q
_tensors
[
i
].
GetLengths
()[
1
];
// copy z matirx data form device
z_fwd_tensors_device
[
i
]
->
FromDevice
(
z_fwd_tensors
[
i
].
mData
.
data
());
z_fwd_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
z_fwd_g_m_ns
[
i
](
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
z_fwd_g_m_ns
[
i
](
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
z_bwd_tensors_device
[
i
]
->
FromDevice
(
z_bwd_tensors
[
i
].
mData
.
data
());
z_bwd_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
z_bwd_g_m_ns
[
i
](
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
z_bwd_g_m_ns
[
i
](
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
run_attention_fwd_host
(
q_g_m_ks
[
i
],
k_g_n_ks
[
i
],
...
...
@@ -863,12 +901,12 @@ int run(int argc, char* argv[])
p_dropout_in_uint8_t
,
rp_dropout
);
int
G0
=
v
_tensors
[
i
].
GetLengths
()[
0
];
int
G0
=
q
_tensors
[
i
].
GetLengths
()[
0
];
int
O
=
v_tensors
[
i
].
GetLengths
()[
2
];
int
N
=
v_tensors
[
i
].
GetLengths
()[
3
];
int
M
=
q_tensors
[
i
].
GetLengths
()[
2
];
int
K
=
q_tensors
[
i
].
GetLengths
()[
3
];
int
BatchCount
=
G0
*
G1
;
int
BatchCount
=
G0
*
G1
Q
;
Tensor
<
OutputDataType
>
qgrad_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
OutputDataType
>
kgrad_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
OutputDataType
>
vgrad_g_n_o
({
BatchCount
,
N
,
O
});
...
...
@@ -878,7 +916,7 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
ygrad_g_m_o
({
BatchCount
,
M
,
O
});
ygrad_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
ygrad_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
ygrad_g_m_o
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
auto
ref_gemm0_grad
=
ReferenceGemm0GradInstance
{};
auto
ref_gemm0_grad_invoker
=
ref_gemm0_grad
.
MakeInvoker
();
...
...
@@ -921,10 +959,10 @@ int run(int argc, char* argv[])
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_host_result
(
q_tensors
[
i
].
GetLengths
(),
q_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
k_tensors
[
i
].
GetLengths
(),
k_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
v_tensors
[
i
].
GetLengths
(),
v_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
k
grad
_tensors
[
i
].
GetLengths
(),
k
grad
_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
v
grad
_tensors
[
i
].
GetLengths
(),
v
grad
_tensors
[
i
].
GetStrides
());
Tensor
<
InputDataType
>
y_gs_ms_os_host_result
(
y_tensors
[
i
].
GetLengths
(),
y_tensors
[
i
].
GetStrides
());
Tensor
<
LSEDataType
>
lse_gs_ms_host_result
(
lse_tensors
[
i
].
GetLengths
(),
...
...
@@ -932,10 +970,10 @@ int run(int argc, char* argv[])
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_device_result
(
q_tensors
[
i
].
GetLengths
(),
q_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
k_tensors
[
i
].
GetLengths
(),
k_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
v_tensors
[
i
].
GetLengths
(),
v_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
k
grad
_tensors
[
i
].
GetLengths
(),
k
grad
_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
v
grad
_tensors
[
i
].
GetLengths
(),
v
grad
_tensors
[
i
].
GetStrides
());
Tensor
<
InputDataType
>
y_gs_ms_os_device_result
(
y_tensors
[
i
].
GetLengths
(),
y_tensors
[
i
].
GetStrides
());
Tensor
<
LSEDataType
>
lse_gs_ms_device_result
(
lse_tensors
[
i
].
GetLengths
(),
...
...
@@ -949,42 +987,42 @@ int run(int argc, char* argv[])
// permute
y_gs_ms_os_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
y_g_m_os
[
i
](
g
,
idx
[
2
],
idx
[
3
]);
});
lse_gs_ms_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
lse_g_ms
[
i
](
g
,
idx
[
2
]);
});
qgrad_gs_ms_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
qgrad_g_m_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
kgrad_gs_ns_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
kgrad_g_n_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
vgrad_gs_os_ns_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
vgrad_g_n_o
(
g
,
idx
[
3
],
idx
[
2
]);
});
...
...
example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc
View file @
e87ddb0e
...
...
@@ -14,11 +14,12 @@ int run(int argc, char* argv[])
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
G0
=
7
;
ck
::
index_t
G1
=
13
;
// Output shape C[G0, M, G1Q, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1q_m_o = reshape(C_g_m_o, [g0, g1q, m, o])
// C_g0_m_g1q_o = permute(C_g0_g1q_m_o, [0, 2, 1, 3])
ck
::
index_t
G0
=
7
;
ck
::
index_t
G1Q
=
12
;
// h_q
ck
::
index_t
G1KV
=
12
;
// h_kv
bool
input_permute
=
false
;
bool
output_permute
=
true
;
...
...
@@ -37,32 +38,33 @@ int run(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
1
3
)
else
if
(
argc
==
1
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1
=
std
::
stoi
(
argv
[
9
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1Q
=
std
::
stoi
(
argv
[
9
]);
G1KV
=
std
::
stoi
(
argv
[
10
]);
p_drop
=
std
::
stof
(
argv
[
1
0
]);
p_drop
=
std
::
stof
(
argv
[
1
1
]);
input_permute
=
std
::
stoi
(
argv
[
1
1
]);
output_permute
=
std
::
stoi
(
argv
[
1
2
]);
input_permute
=
std
::
stoi
(
argv
[
1
2
]);
output_permute
=
std
::
stoi
(
argv
[
1
3
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 1
1
: M, N, K, O, G0, G1
\n
"
);
printf
(
"arg1
0
:
scale (alpha)
\n
"
);
printf
(
"arg1
1
to 1
2
: input / output permute
\n
"
);
printf
(
"arg4 to 1
0
: M, N, K, O, G0, G1
Q, G1KV
\n
"
);
printf
(
"arg1
1
:
p_drop
\n
"
);
printf
(
"arg1
2
to 1
3
: input / output permute
\n
"
);
exit
(
0
);
}
...
...
@@ -71,39 +73,39 @@ int run(int argc, char* argv[])
float
rp_dropout
=
1.0
/
p_dropout
;
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
Q
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// A layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// A layout [G0, G1, M, K]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
K
,
K
,
G1
Q
*
K
,
1
}
// A layout [G0, M, G1
Q
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
K
,
M
*
K
,
K
,
1
};
// A layout [G0, G1
Q
, M, K]
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_lengths
{
G0
,
G1
KV
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// B0 layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// B0 layout [G0, G1, N, K]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
KV
*
K
,
K
,
G1
KV
*
K
,
1
}
// B0 layout [G0, N, G1
KV
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
KV
*
N
*
K
,
N
*
K
,
K
,
1
};
// B0 layout [G0, G1
KV
, N, K]
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_lengths
{
G0
,
G1
KV
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// B1 layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// B1 layout [G0, G1, N, O]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
KV
*
O
,
O
,
1
,
G1
KV
*
O
}
// B1 layout [G0, N, G1
KV
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
KV
*
N
*
O
,
N
*
O
,
1
,
O
};
// B1 layout [G0, G1
KV
, N, O]
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_lengths
{
G0
,
G1
Q
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_strides
=
output_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// C layout [G0, M, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// C layout [G0, G1, M, O]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
O
,
O
,
G1
Q
*
O
,
1
}
// C layout [G0, M, G1
Q
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
O
,
M
*
O
,
O
,
1
};
// C layout [G0, G1
Q
, M, O]
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
Q
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
N
,
N
,
G1
Q
*
N
,
1
}
// Z layout [G0, M, G1
Q
, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1
Q
, M, N]
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
Q
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
=
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
,
M
,
1
};
// LSE layout [G0, G1, M]
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
,
M
,
1
};
// LSE layout [G0, G1
Q
, M]
Tensor
<
ADataType
>
a_gs_ms_ks
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
);
Tensor
<
B0DataType
>
b0_gs_ns_ks
(
b0_gs_ns_ks_lengths
,
b0_gs_ns_ks_strides
);
...
...
@@ -211,7 +213,7 @@ int run(int argc, char* argv[])
return
0
;
}
ck
::
index_t
BatchCount
=
G0
*
G1
;
ck
::
index_t
BatchCount
=
G0
*
G1
Q
;
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
...
...
@@ -276,24 +278,32 @@ int run(int argc, char* argv[])
Tensor
<
B1DataType
>
b1_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
AccDataType
>
acc0_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after gemm0
Tensor
<
ADataType
>
a1_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after softmax
Tensor
<
ADataType
>
a1_g_m_n_drop
({
G0
*
G1
,
M
,
N
});
Tensor
<
ADataType
>
a1_g_m_n_drop
({
BatchCount
,
M
,
N
});
Tensor
<
LSEDataType
>
lse_g_m_host_result
(
{
BatchCount
,
M
});
// scratch object after max + ln(sum)
Tensor
<
ZDataType
>
z_g_m_n
({
G0
*
G1
,
M
,
N
});
Tensor
<
ZDataType
>
z_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
CDataType
>
c_g_m_o_host_result
({
BatchCount
,
M
,
O
});
// scratch object after gemm1
// permute
a_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
a_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
a_g_m_k
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
b0_gs_ns_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b0_g_k_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
b0_g_k_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
(
G1Q
/
G1KV
);
self
(
idx
)
=
b0_gs_ns_ks
(
g0
,
g1kv
,
idx
[
2
],
idx
[
1
]);
});
b1_gs_os_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b1_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
b1_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
(
G1Q
/
G1KV
);
self
(
idx
)
=
b1_gs_os_ns
(
g0
,
g1kv
,
idx
[
2
],
idx
[
1
]);
});
z_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
z_g_m_n
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
// gemm 0
...
...
@@ -340,18 +350,18 @@ int run(int argc, char* argv[])
// permute
c_gs_ms_os_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
c_g_m_o_host_result
(
g
,
idx
[
2
],
idx
[
3
]);
});
lse_gs_ms_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
lse_g_m_host_result
(
g
,
idx
[
2
]);
});
...
...
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
View file @
e87ddb0e
...
...
@@ -11,6 +11,7 @@ int run(int argc, char* argv[])
bool
output_permute
=
true
;
float
p_drop
=
0.2
;
int
h_ratio
=
1
;
// G1Q / G1KV
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
offset
=
0
;
...
...
@@ -24,22 +25,25 @@ int run(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
7
)
else
if
(
argc
==
8
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
p_drop
=
std
::
stoi
(
argv
[
4
]);
input_permute
=
std
::
stoi
(
argv
[
5
]);
output_permute
=
std
::
stoi
(
argv
[
6
]);
h_ratio
=
std
::
stof
(
argv
[
5
]);
input_permute
=
std
::
stoi
(
argv
[
6
]);
output_permute
=
std
::
stoi
(
argv
[
7
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 5: input / output permute
\n
"
);
printf
(
"arg4: p_drop
\n
"
);
printf
(
"arg5: h_ratio
\n
"
);
printf
(
"arg6 to 7: input / output permute
\n
"
);
exit
(
0
);
}
...
...
@@ -60,7 +64,7 @@ int run(int argc, char* argv[])
std
::
vector
<
void
*>
p_z
;
// for result verification
std
::
vector
<
void
*>
p_z_nullptr
;
// for time test
std
::
vector
<
void
*>
p_lse
;
std
::
vector
<
std
::
vector
<
int
>>
g0_g1_m_n_k_o
;
std
::
vector
<
std
::
vector
<
int
>>
g0_g1
q
_m_n_k_o
;
std
::
vector
<
Tensor
<
ADataType
>>
a_tensors
;
std
::
vector
<
Tensor
<
B0DataType
>>
b0_tensors
;
...
...
@@ -83,48 +87,51 @@ int run(int argc, char* argv[])
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
M
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
N
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
K
=
DIM
;
int
O
=
DIM
;
int
G0
=
rand
()
%
3
+
1
;
int
G1
=
rand
()
%
5
+
1
;
int
M
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
N
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
K
=
DIM
;
int
O
=
DIM
;
int
G0
=
rand
()
%
3
+
1
;
int
G1KV
=
rand
()
%
5
+
1
;
int
G1Q
=
G1KV
*
h_ratio
;
g0_g1_m_n_k_o
.
push_back
({
G0
,
G1
,
M
,
N
,
K
,
O
});
g0_g1
q
_m_n_k_o
.
push_back
({
G0
,
G1
Q
,
M
,
N
,
K
,
O
});
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
Q
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// A layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// A layout [G0, G1, M, K]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
K
,
K
,
G1
Q
*
K
,
1
}
// A layout [G0, M, G1
Q
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
K
,
M
*
K
,
K
,
1
};
// A layout [G0, G1
Q
, M, K]
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_lengths
{
G0
,
G1
KV
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// B0 layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// B0 layout [G0, G1, N, K]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1KV
*
K
,
K
,
G1KV
*
K
,
1
}
// B0 layout [G0, N, G1KV, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1KV
*
N
*
K
,
N
*
K
,
K
,
1
};
// B0 layout [G0, G1KV, N, K]
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_lengths
{
G0
,
G1
KV
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// B1 layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// B1 layout [G0, G1, N, O]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1KV
*
O
,
O
,
1
,
G1KV
*
O
}
// B1 layout [G0, N, G1KV, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1KV
*
N
*
O
,
N
*
O
,
1
,
O
};
// B1 layout [G0, G1KV, N, O]
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_lengths
{
G0
,
G1
Q
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_strides
=
output_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// C layout [G0, M, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// C layout [G0, G1, M, O]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
O
,
O
,
G1
Q
*
O
,
1
}
// C layout [G0, M, G1
Q
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
O
,
M
*
O
,
O
,
1
};
// C layout [G0, G1
Q
, M, O]
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
Q
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
N
,
N
,
G1
Q
*
N
,
1
}
// Z layout [G0, M, G1
Q
, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1
Q
, M, N]
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
Q
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
=
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
,
M
,
1
};
// LSE layout [G0, G1, M]
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
,
M
,
1
};
// LSE layout [G0, G1
Q
, M]
problem_descs
.
push_back
({
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
...
...
@@ -151,7 +158,7 @@ int run(int argc, char* argv[])
Tensor
<
ZDataType
>
z_gs_ms_ns
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms_device_result
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
int
Batch
=
G0
*
G1
;
int
Batch
=
G0
*
G1
Q
;
flop
+=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
)
*
Batch
;
num_byte
+=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
sizeof
(
B1DataType
)
*
N
*
O
+
sizeof
(
CDataType
)
*
M
*
O
)
*
...
...
@@ -308,12 +315,12 @@ int run(int argc, char* argv[])
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
const
int
&
G0
=
g0_g1_m_n_k_o
[
i
][
0
];
const
int
&
G1
=
g0_g1_m_n_k_o
[
i
][
1
];
const
int
&
M
=
g0_g1_m_n_k_o
[
i
][
2
];
const
int
&
N
=
g0_g1_m_n_k_o
[
i
][
3
];
const
int
&
K
=
g0_g1_m_n_k_o
[
i
][
4
];
const
int
&
O
=
g0_g1_m_n_k_o
[
i
][
5
];
const
int
&
G0
=
g0_g1
q
_m_n_k_o
[
i
][
0
];
const
int
&
G1
Q
=
g0_g1
q
_m_n_k_o
[
i
][
1
];
const
int
&
M
=
g0_g1
q
_m_n_k_o
[
i
][
2
];
const
int
&
N
=
g0_g1
q
_m_n_k_o
[
i
][
3
];
const
int
&
K
=
g0_g1
q
_m_n_k_o
[
i
][
4
];
const
int
&
O
=
g0_g1
q
_m_n_k_o
[
i
][
5
];
const
auto
&
c_gs_ms_os_lengths
=
problem_descs
[
i
]
.
c_gs_ms_os_lengths
;
const
auto
&
c_gs_ms_os_strides
=
problem_descs
[
i
]
.
c_gs_ms_os_strides
;
...
...
@@ -334,31 +341,39 @@ int run(int argc, char* argv[])
z_gs_ms_ns_device_buf
.
FromDevice
(
z_gs_ms_ns_device_result
.
mData
.
data
());
lse_gs_ms_device_buf
.
FromDevice
(
lse_gs_ms_device_result
.
mData
.
data
());
Tensor
<
ADataType
>
a_g_m_k
({
G0
*
G1
,
M
,
K
});
Tensor
<
B0DataType
>
b0_g_k_n
({
G0
*
G1
,
K
,
N
});
Tensor
<
B1DataType
>
b1_g_n_o
({
G0
*
G1
,
N
,
O
});
Tensor
<
AccDataType
>
acc0_g_m_n
({
G0
*
G1
,
M
,
N
});
// scratch object after gemm0
Tensor
<
ADataType
>
a1_g_m_n
({
G0
*
G1
,
M
,
N
});
// scratch object after softmax
Tensor
<
ADataType
>
a1_g_m_n_drop
({
G0
*
G1
,
M
,
N
});
// scratch object after softmax
Tensor
<
CDataType
>
c_g_m_o_host_result
({
G0
*
G1
,
M
,
O
});
// scratch object after gemm1
Tensor
<
ADataType
>
a_g_m_k
({
G0
*
G1
Q
,
M
,
K
});
Tensor
<
B0DataType
>
b0_g_k_n
({
G0
*
G1
Q
,
K
,
N
});
Tensor
<
B1DataType
>
b1_g_n_o
({
G0
*
G1
Q
,
N
,
O
});
Tensor
<
AccDataType
>
acc0_g_m_n
({
G0
*
G1
Q
,
M
,
N
});
// scratch object after gemm0
Tensor
<
ADataType
>
a1_g_m_n
({
G0
*
G1
Q
,
M
,
N
});
// scratch object after softmax
Tensor
<
ADataType
>
a1_g_m_n_drop
({
G0
*
G1
Q
,
M
,
N
});
// scratch object after softmax
Tensor
<
CDataType
>
c_g_m_o_host_result
({
G0
*
G1
Q
,
M
,
O
});
// scratch object after gemm1
Tensor
<
CDataType
>
c_gs_ms_os_host_result
(
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
);
Tensor
<
ZDataType
>
z_g_m_n
({
G0
*
G1
,
M
,
N
});
Tensor
<
LSEDataType
>
lse_g_m_host_result
({
G0
*
G1
,
M
});
// scratch object after gemm1
Tensor
<
ZDataType
>
z_g_m_n
({
G0
*
G1
Q
,
M
,
N
});
Tensor
<
LSEDataType
>
lse_g_m_host_result
({
G0
*
G1
Q
,
M
});
// scratch object after gemm1
Tensor
<
LSEDataType
>
lse_gs_ms_host_result
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
// permute
a_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
a_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
a_g_m_k
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
b0_gs_ns_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b0_g_k_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
b0_g_k_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
h_ratio
;
self
(
idx
)
=
b0_gs_ns_ks
(
g0
,
g1kv
,
idx
[
2
],
idx
[
1
]);
});
b1_gs_os_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b1_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
b1_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
h_ratio
;
self
(
idx
)
=
b1_gs_os_ns
(
g0
,
g1kv
,
idx
[
2
],
idx
[
1
]);
});
z_gs_ms_ns_device_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
z_g_m_n
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
// gemm 0
...
...
@@ -408,18 +423,18 @@ int run(int argc, char* argv[])
// permute
c_gs_ms_os_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
c_g_m_o_host_result
(
g
,
idx
[
2
],
idx
[
3
]);
});
lse_gs_ms_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
lse_g_m_host_result
(
g
,
idx
[
2
]);
});
...
...
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp
View file @
e87ddb0e
...
...
@@ -273,14 +273,15 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
M
=
512
;
ck
::
index_t
N
=
512
;
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
G0
=
4
;
ck
::
index_t
G1
=
6
;
// y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
// y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
ck
::
index_t
M
=
512
;
ck
::
index_t
N
=
512
;
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
G0
=
4
;
ck
::
index_t
G1Q
=
6
;
// h_q
ck
::
index_t
G1KV
=
6
;
// h_kv
bool
input_permute
=
false
;
bool
output_permute
=
false
;
...
...
@@ -299,32 +300,33 @@ int run(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
1
3
)
else
if
(
argc
==
1
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1
=
std
::
stoi
(
argv
[
9
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1Q
=
std
::
stoi
(
argv
[
9
]);
G1KV
=
std
::
stoi
(
argv
[
10
]);
p_drop
=
std
::
stof
(
argv
[
1
0
]);
p_drop
=
std
::
stof
(
argv
[
1
1
]);
input_permute
=
std
::
stoi
(
argv
[
1
1
]);
output_permute
=
std
::
stoi
(
argv
[
1
2
]);
input_permute
=
std
::
stoi
(
argv
[
1
2
]);
output_permute
=
std
::
stoi
(
argv
[
1
3
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 1
1
: M, N, K, O, G0, G1
\n
"
);
printf
(
"arg1
0
:
scale (alpha)
\n
"
);
printf
(
"arg1
1
to 1
2
: input / output permute
\n
"
);
printf
(
"arg4 to 1
0
: M, N, K, O, G0, G1
Q, G1KV
\n
"
);
printf
(
"arg1
1
:
p_drop
\n
"
);
printf
(
"arg1
2
to 1
3
: input / output permute
\n
"
);
exit
(
0
);
}
...
...
@@ -341,7 +343,8 @@ int run(int argc, char* argv[])
std
::
cout
<<
"K: "
<<
K
<<
std
::
endl
;
std
::
cout
<<
"O: "
<<
O
<<
std
::
endl
;
std
::
cout
<<
"G0: "
<<
G0
<<
std
::
endl
;
std
::
cout
<<
"G1: "
<<
G1
<<
std
::
endl
;
std
::
cout
<<
"G1Q: "
<<
G1Q
<<
std
::
endl
;
std
::
cout
<<
"G1KV: "
<<
G1KV
<<
std
::
endl
;
std
::
cout
<<
"alpha: "
<<
alpha
<<
std
::
endl
;
std
::
cout
<<
"input_permute: "
<<
input_permute
<<
std
::
endl
;
std
::
cout
<<
"output_permute: "
<<
output_permute
<<
std
::
endl
;
...
...
@@ -349,51 +352,63 @@ int run(int argc, char* argv[])
std
::
cout
<<
"seed: "
<<
seed
<<
std
::
endl
;
std
::
cout
<<
"offset: "
<<
offset
<<
std
::
endl
;
const
ck
::
index_t
BatchCount
=
G0
*
G1
;
const
ck
::
index_t
BatchCount
=
G0
*
G1
Q
;
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
Q
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// Q layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1, M, K]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
K
,
K
,
G1
Q
*
K
,
1
}
// Q layout [G0, M, G1
Q
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1
Q
, M, K]
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G1
KV
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// K layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G1, N, K]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
KV
*
K
,
K
,
G1
KV
*
K
,
1
}
// K layout [G0, N, G1
KV
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
KV
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G1
KV
, N, K]
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G1
KV
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// V layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G1, N, O]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
KV
*
O
,
O
,
1
,
G1
KV
*
O
}
// V layout [G0, N, G1
KV
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
KV
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G1
KV
, N, O]
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
Q
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_strides
=
output_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// Y layout [G0, M, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1, M, O]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
O
,
O
,
G1
Q
*
O
,
1
}
// Y layout [G0, M, G1
Q
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1
Q
, M, O]
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_lengths
{
G0
,
G1
Q
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// D layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// D layout [G0, G1, M, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
N
,
N
,
G1
Q
*
N
,
1
}
// D layout [G0, M, G1
Q
, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
N
,
M
*
N
,
N
,
1
};
// D layout [G0, G1
Q
, M, N]
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
Q
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1Q
*
N
,
N
,
G1Q
*
N
,
1
}
// Z layout [G0, M, G1Q, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1Q, M, N]
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_lengths
{
G0
,
G1Q
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1Q
*
K
,
K
,
G1Q
*
K
,
1
}
// KGrad layout [G0, N, G1Q, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
N
*
K
,
N
*
K
,
K
,
1
};
// KGrad layout [G0, G1Q, N, K]
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_lengths
{
G0
,
G1Q
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1Q
*
O
,
O
,
1
,
G1Q
*
O
}
// VGrad layout [G0, N, G1Q, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
N
*
O
,
N
*
O
,
1
,
O
};
// VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^
// LSE
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
*
M
,
M
,
1
};
// LSE layout [G0, G1, M]
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
Q
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
Q
*
M
,
M
,
1
};
// LSE layout [G0, G1
Q
, M]
Tensor
<
InputDataType
>
q_gs_ms_ks
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
InputDataType
>
k_gs_ns_ks
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
...
...
@@ -403,6 +418,8 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
y_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
InputDataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks
(
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns
(
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
);
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"k_gs_ns_ks: "
<<
k_gs_ns_ks
.
mDesc
<<
std
::
endl
;
...
...
@@ -411,6 +428,8 @@ int run(int argc, char* argv[])
std
::
cout
<<
"v_gs_os_ns: "
<<
v_gs_os_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"kgrad_gs_ns_ks: "
<<
kgrad_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"vgrad_gs_os_ns: "
<<
vgrad_gs_os_ns
.
mDesc
<<
std
::
endl
;
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
0
});
switch
(
init_method
)
...
...
@@ -448,7 +467,7 @@ int run(int argc, char* argv[])
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1
q
, m, o]
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
// dO dot O = [0; 1; 2; ...]
break
;
...
...
@@ -456,7 +475,7 @@ int run(int argc, char* argv[])
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1
q
, m, o]
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
...
...
@@ -470,7 +489,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0,g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1q, m, o]
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
...
...
@@ -491,8 +511,8 @@ int run(int argc, char* argv[])
DeviceMem
y_device_buf
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
lse_device_buf
(
sizeof
(
LSEDataType
)
*
lse_gs_ms
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
qgrad_device_buf
(
sizeof
(
OutputDataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
kgrad_device_buf
(
sizeof
(
OutputDataType
)
*
k_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
vgrad_device_buf
(
sizeof
(
OutputDataType
)
*
v_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
kgrad_device_buf
(
sizeof
(
OutputDataType
)
*
k
grad
_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
vgrad_device_buf
(
sizeof
(
OutputDataType
)
*
v
grad
_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
ygrad_device_buf
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d0grad_device_buf
(
sizeof
(
Acc0BiasDataType
)
*
d0_gs_ms_ns
.
mDesc
.
GetElementSpaceSize
());
...
...
@@ -533,6 +553,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
,
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
,
d0_gs_ms_ns_lengths
,
// acc0_bias_gs_ms_ns_lengths
d0_gs_ms_ns_strides
,
// acc0_bias_gs_ms_ns_strides
{},
// acc1_bias_gs_ms_os_lengths,
...
...
@@ -580,6 +604,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
,
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
,
d0_gs_ms_ns_lengths
,
// acc0_bias_gs_ms_ns_lengths
d0_gs_ms_ns_strides
,
// acc0_bias_gs_ms_ns_strides
{},
// acc1_bias_gs_ms_os_lengths,
...
...
@@ -624,7 +652,7 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
q_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
InputDataType
>
k_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
Acc0BiasDataType
>
d0_g_m_n
({
G0
*
G1
,
M
,
N
});
Tensor
<
Acc0BiasDataType
>
d0_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
ZDataType
>
z_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
InputDataType
>
v_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
AccDataType
>
s_g_m_n
({
BatchCount
,
M
,
N
});
...
...
@@ -635,19 +663,27 @@ int run(int argc, char* argv[])
z_device_buf
.
FromDevice
(
z_gs_ms_ns
.
mData
.
data
());
z_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
z_g_m_n
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
q_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
q_g_m_k
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_gs_ns_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
k_g_n_k
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
(
G1Q
/
G1KV
);
self
(
idx
)
=
k_gs_ns_ks
(
g0
,
g1kv
,
idx
[
1
],
idx
[
2
]);
});
v_gs_os_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
v_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
(
G1Q
/
G1KV
);
self
(
idx
)
=
v_gs_os_ns
(
g0
,
g1kv
,
idx
[
2
],
idx
[
1
]);
});
d0_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
d0_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
d0_g_m_n
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
// run fwd again for y, cause z_g_m_n update
run_attention_fwd_host
(
q_g_m_k
,
...
...
@@ -664,10 +700,10 @@ int run(int argc, char* argv[])
p_dropout_in_uint8_t
,
rp_dropout
);
y_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
lse_gs_ms
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
]);
});
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_m
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
]);
});
y_device_buf
.
ToDevice
(
y_gs_ms_os
.
mData
.
data
());
lse_device_buf
.
ToDevice
(
lse_gs_ms
.
mData
.
data
());
...
...
@@ -685,7 +721,7 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
ygrad_dot_y_g_m
({
BatchCount
,
M
});
ygrad_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
ygrad_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
ygrad_g_m_o
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
#if PRINT_HOST
...
...
@@ -787,14 +823,18 @@ int run(int argc, char* argv[])
#endif
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_host_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
);
Tensor
<
Acc0BiasDataType
>
d0grad_gs_ms_ns_host_result
(
d0_gs_ms_ns_lengths
,
d0_gs_ms_ns_strides
);
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_device_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
);
Tensor
<
Acc0BiasDataType
>
d0grad_gs_ms_ns_device_result
(
d0_gs_ms_ns_lengths
,
d0_gs_ms_ns_strides
);
...
...
@@ -805,35 +845,35 @@ int run(int argc, char* argv[])
// permute
qgrad_gs_ms_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
qgrad_g_m_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
kgrad_gs_ns_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
kgrad_g_n_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
vgrad_gs_os_ns_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
vgrad_g_n_o
(
g
,
idx
[
3
],
idx
[
2
]);
});
d0grad_gs_ms_ns_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
sgrad_g_m_n
(
g
,
idx
[
2
],
idx
[
3
]);
});
...
...
example/52_flash_atten_bias/batched_multihead_attention_bias_forward_v2.cpp
View file @
e87ddb0e
...
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
bool
Deterministic
=
false
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
#if(DIM <= 32)
using
DeviceGemmInstance
=
...
...
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
MaskingSpec
>
;
// MaskingSpecialization
#elif(DIM <= 64)
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
...
...
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
MaskingSpec
>
;
// MaskingSpecialization
#elif(DIM <= 128)
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
...
...
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
MaskingSpec
>
;
// MaskingSpecialization
#endif
// Ref Gemm0: DataType in, AccDataType out
...
...
example/52_flash_atten_bias/batched_multihead_attention_bias_forward_v2_zcheck.cpp
View file @
e87ddb0e
...
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
bool
Deterministic
=
false
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
#if(DIM <= 32)
using
DeviceGemmInstance
=
...
...
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
MaskingSpec
>
;
// MaskingSpecialization
#elif(DIM <= 64)
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
...
...
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
MaskingSpec
>
;
// MaskingSpecialization
#elif(DIM <= 128)
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
<
...
...
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
4
,
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
MaskingSpec
>
;
// MaskingSpecialization
#endif
using
DeviceDropoutInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedDropout
<
NumDimG
,
...
...
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp
View file @
e87ddb0e
...
...
@@ -24,7 +24,7 @@ Kernel outputs:
*/
#define USING_MASK 0
#define DIM
128
// DIM should be a multiple of 8.
#define DIM
64
// DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
...
...
@@ -271,10 +271,11 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
// y_g0_g1
q
_m_o = reshape(y_g_m_o, [G0, G1
Q
, M, O])
// y_g0_m_g1
q
_o = permute(y_g0_g1
q
_m_o, [0, 2, 1, 3])
float
alpha
=
1.
f
/
std
::
sqrt
(
DIM
);
float
p_drop
=
0.0
;
int
h_ratio
=
1
;
// G1Q / G1KV
bool
input_permute
=
true
;
bool
output_permute
=
true
;
...
...
@@ -292,25 +293,26 @@ int run(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
7
)
else
if
(
argc
==
8
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
p_drop
=
std
::
stof
(
argv
[
4
]);
p_drop
=
std
::
stof
(
argv
[
4
]);
h_ratio
=
std
::
stof
(
argv
[
5
]);
input_permute
=
std
::
stoi
(
argv
[
5
]);
output_permute
=
std
::
stoi
(
argv
[
6
]);
input_permute
=
std
::
stoi
(
argv
[
6
]);
output_permute
=
std
::
stoi
(
argv
[
7
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4
to 11: M, N, K, O, G0, G1
\n
"
);
printf
(
"arg
10: scale (alpha)
\n
"
);
printf
(
"arg
11
to
12
: input / output permute
\n
"
);
printf
(
"arg4
: p_drop
\n
"
);
printf
(
"arg
5: h_ratio
\n
"
);
printf
(
"arg
6
to
7
: input / output permute
\n
"
);
exit
(
0
);
}
...
...
@@ -377,55 +379,71 @@ int run(int argc, char* argv[])
std
::
size_t
flop
=
0
,
num_byte
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
M
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
N
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
K
=
DIM
;
int
O
=
DIM
;
int
G0
=
rand
()
%
4
+
1
;
int
G1
=
rand
()
%
4
+
1
;
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
int
M
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
N
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
K
=
DIM
;
int
O
=
DIM
;
int
G0
=
rand
()
%
4
+
1
;
int
G1KV
=
rand
()
%
4
+
1
;
int
G1Q
=
G1KV
*
h_ratio
;
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1Q
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// Q layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1, M, K]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
K
,
K
,
G1
Q
*
K
,
1
}
// Q layout [G0, M, G1
Q
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1
Q
, M, K]
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G1
KV
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// K layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G1, N, K]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1KV
*
K
,
K
,
G1KV
*
K
,
1
}
// K layout [G0, N, G1KV, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1KV
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G1KV, N, K]
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G1
KV
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// V layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G1, N, O]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1KV
*
O
,
O
,
1
,
G1KV
*
O
}
// V layout [G0, N, G1KV, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1KV
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G1KV, N, O]
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
Q
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_strides
=
output_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// Y layout [G0, M, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1, M, O]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
O
,
O
,
G1
Q
*
O
,
1
}
// Y layout [G0, M, G1
Q
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1
Q
, M, O]
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_lengths
{
G0
,
G1
Q
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// d0 layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// d0 layout [G0, G1, M, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
N
,
N
,
G1
Q
*
N
,
1
}
// d0 layout [G0, M, G1
Q
, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
N
,
M
*
N
,
N
,
1
};
// d0 layout [G0, G1
Q
, M, N]
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
Q
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1Q
*
N
,
N
,
G1Q
*
N
,
1
}
// Z layout [G0, M, G1Q, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1Q, M, N]
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_lengths
{
G0
,
G1Q
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1Q
*
K
,
K
,
G1Q
*
K
,
1
}
// KGrad layout [G0, N, G1Q, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
N
*
K
,
N
*
K
,
K
,
1
};
// KGrad layout [G0, G1Q, N, K]
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_lengths
{
G0
,
G1Q
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1Q
*
O
,
O
,
1
,
G1Q
*
O
}
// VGrad layout [G0, N, G1Q, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
N
*
O
,
N
*
O
,
1
,
O
};
// VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
// pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^
// LSE
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
*
M
,
M
,
1
};
// LSE layout [G0, G1, M]
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
Q
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
Q
*
M
,
M
,
1
};
// LSE layout [G0, G1
Q
, M]
problem_descs
.
push_back
({
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
,
...
...
@@ -439,13 +457,17 @@ int run(int argc, char* argv[])
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
lse_gs_ms_strides
,
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
,
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
,
d0_gs_ms_ns_lengths
,
d0_gs_ms_ns_strides
,
{},
// acc1_bias_gs_ms_os_lengths,
{},
// acc1_bias_gs_ms_os_strides,
});
int
BatchCount
=
G0
*
G1
;
int
BatchCount
=
G0
*
G1
Q
;
flop
+=
(
size_t
(
3
)
*
M
*
N
*
K
+
size_t
(
2
)
*
M
*
N
*
O
)
*
2
*
BatchCount
;
// Q/K/V/Y, dQ/dK/dV/dY, LSE
num_byte
+=
...
...
@@ -464,6 +486,8 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
y_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
InputDataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks
(
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns
(
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
);
if
(
i
<
4
)
{
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
...
...
@@ -473,6 +497,8 @@ int run(int argc, char* argv[])
std
::
cout
<<
"v_gs_os_ns: "
<<
v_gs_os_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"kgrad_gs_ns_ks: "
<<
kgrad_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"vgrad_gs_os_ns: "
<<
vgrad_gs_os_ns
.
mDesc
<<
std
::
endl
;
}
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
0
});
switch
(
init_method
)
...
...
@@ -510,7 +536,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1q, m, o]
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
// dO dot O = [0; 1; 2; ...]
break
;
...
...
@@ -518,7 +545,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1q, m, o]
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
...
...
@@ -533,7 +561,7 @@ int run(int argc, char* argv[])
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1, m, o]
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1
q
, m, o]
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
...
...
@@ -556,16 +584,24 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
p_drop_g_m_n
({
BatchCount
,
M
,
N
});
q_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
q_g_m_k
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_gs_ns_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
k_g_n_k
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
h_ratio
;
self
(
idx
)
=
k_gs_ns_ks
(
g0
,
g1kv
,
idx
[
1
],
idx
[
2
]);
});
d0_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
d0_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
d0_g_m_n
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
v_gs_os_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
v_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
h_ratio
;
self
(
idx
)
=
v_gs_os_ns
(
g0
,
g1kv
,
idx
[
2
],
idx
[
1
]);
});
q_g_m_ks
.
push_back
(
q_g_m_k
);
...
...
@@ -586,6 +622,8 @@ int run(int argc, char* argv[])
z_tensors
.
push_back
(
z_gs_ms_ns
);
lse_tensors
.
push_back
(
lse_gs_ms
);
ygrad_tensors
.
push_back
(
ygrad_gs_ms_os
);
kgrad_tensors
.
push_back
(
kgrad_gs_ns_ks
);
vgrad_tensors
.
push_back
(
vgrad_gs_os_ns
);
q_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
InputDataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
k_tensors_device
.
emplace_back
(
...
...
@@ -602,12 +640,12 @@ int run(int argc, char* argv[])
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
LSEDataType
)
*
lse_gs_ms
.
GetElementSpaceSize
()));
qgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
kgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
k_gs_ns_ks
.
GetElementSpaceSize
()));
kgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
k
grad
_gs_ns_ks
.
GetElementSpaceSize
()));
d0grad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
Acc0BiasDataType
)
*
d0_gs_ms_ns
.
GetElementSpaceSize
()));
vgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
v_gs_os_ns
.
GetElementSpaceSize
()));
vgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
v
grad
_gs_os_ns
.
GetElementSpaceSize
()));
ygrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
GetElementSpaceSize
()));
q_tensors_device
.
back
()
->
ToDevice
(
q_gs_ms_ks
.
data
());
...
...
@@ -713,11 +751,11 @@ int run(int argc, char* argv[])
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
G1
=
v
_tensors
[
i
].
GetLengths
()[
1
];
int
G1
Q
=
q
_tensors
[
i
].
GetLengths
()[
1
];
// copy z matirx data form device
z_tensors_device
[
i
]
->
FromDevice
(
z_tensors
[
i
].
mData
.
data
());
z_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_ns
[
i
](
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
z_g_m_ns
[
i
](
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
run_attention_fwd_host
(
q_g_m_ks
[
i
],
k_g_n_ks
[
i
],
...
...
@@ -734,11 +772,11 @@ int run(int argc, char* argv[])
rp_dropout
);
y_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_os
[
i
](
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
self
(
idx
)
=
y_g_m_os
[
i
](
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
y_tensors_device
[
i
]
->
ToDevice
(
y_tensors
[
i
].
data
());
lse_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_ms
[
i
](
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
]);
self
(
idx
)
=
lse_g_ms
[
i
](
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
]);
});
lse_tensors_device
[
i
]
->
ToDevice
(
lse_tensors
[
i
].
data
());
qgrad_tensors_device
[
i
]
->
SetZero
();
...
...
@@ -752,13 +790,13 @@ int run(int argc, char* argv[])
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
G0
=
v
_tensors
[
i
].
GetLengths
()[
0
];
int
G1
=
v
_tensors
[
i
].
GetLengths
()[
1
];
int
G0
=
q
_tensors
[
i
].
GetLengths
()[
0
];
int
G1
Q
=
q
_tensors
[
i
].
GetLengths
()[
1
];
int
O
=
v_tensors
[
i
].
GetLengths
()[
2
];
int
N
=
v_tensors
[
i
].
GetLengths
()[
3
];
int
M
=
q_tensors
[
i
].
GetLengths
()[
2
];
int
K
=
q_tensors
[
i
].
GetLengths
()[
3
];
int
BatchCount
=
G0
*
G1
;
int
BatchCount
=
G0
*
G1
Q
;
Tensor
<
OutputDataType
>
qgrad_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
OutputDataType
>
kgrad_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
OutputDataType
>
vgrad_g_n_o
({
BatchCount
,
N
,
O
});
...
...
@@ -768,7 +806,7 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
ygrad_g_m_o
({
BatchCount
,
M
,
O
});
ygrad_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
ygrad_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
ygrad_g_m_o
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
auto
ref_gemm0_grad
=
ReferenceGemm0GradInstance
{};
auto
ref_gemm0_grad_invoker
=
ref_gemm0_grad
.
MakeInvoker
();
...
...
@@ -814,21 +852,21 @@ int run(int argc, char* argv[])
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_host_result
(
q_tensors
[
i
].
GetLengths
(),
q_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
k_tensors
[
i
].
GetLengths
(),
k_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
k
grad
_tensors
[
i
].
GetLengths
(),
k
grad
_tensors
[
i
].
GetStrides
());
Tensor
<
Acc0BiasDataType
>
d0grad_gs_ms_ns_host_result
(
d0_tensors
[
i
].
GetLengths
(),
d0_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
v_tensors
[
i
].
GetLengths
(),
v_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
v
grad
_tensors
[
i
].
GetLengths
(),
v
grad
_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_device_result
(
q_tensors
[
i
].
GetLengths
(),
q_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
k_tensors
[
i
].
GetLengths
(),
k_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
k
grad
_tensors
[
i
].
GetLengths
(),
k
grad
_tensors
[
i
].
GetStrides
());
Tensor
<
Acc0BiasDataType
>
d0grad_gs_ms_ns_device_result
(
d0_tensors
[
i
].
GetLengths
(),
d0_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
v_tensors
[
i
].
GetLengths
(),
v_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
v
grad
_tensors
[
i
].
GetLengths
(),
v
grad
_tensors
[
i
].
GetStrides
());
qgrad_tensors_device
[
i
]
->
FromDevice
(
qgrad_gs_ms_ks_device_result
.
data
());
kgrad_tensors_device
[
i
]
->
FromDevice
(
kgrad_gs_ns_ks_device_result
.
data
());
...
...
@@ -836,34 +874,34 @@ int run(int argc, char* argv[])
vgrad_tensors_device
[
i
]
->
FromDevice
(
vgrad_gs_os_ns_device_result
.
data
());
// permute
qgrad_gs_ms_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
qgrad_g_m_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
kgrad_gs_ns_ks_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
kgrad_g_n_k
(
g
,
idx
[
2
],
idx
[
3
]);
});
d0grad_gs_ms_ns_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
sgrad_g_m_n
(
g
,
idx
[
2
],
idx
[
3
]);
});
vgrad_gs_os_ns_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
vgrad_g_n_o
(
g
,
idx
[
3
],
idx
[
2
]);
});
...
...
example/52_flash_atten_bias/grouped_multihead_attention_bias_forward_v2.cpp
View file @
e87ddb0e
...
...
@@ -71,11 +71,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
static
constexpr
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
bool
Deterministic
=
false
;
static
constexpr
auto
TensorSpecA
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB0
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
#if(DIM <= 32)
using
DeviceGemmInstance
=
...
...
@@ -149,8 +148,7 @@ using DeviceGemmInstance =
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
1
,
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
MaskingSpec
>
;
// MaskingSpecialization
#elif(DIM <= 64)
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
...
...
@@ -223,8 +221,7 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
1
,
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
MaskingSpec
>
;
// MaskingSpecialization
#elif(DIM <= 128)
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
<
...
...
@@ -297,8 +294,7 @@ using DeviceGemmInstance =
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
1
,
MaskingSpec
,
// MaskingSpecialization
Deterministic
>
;
MaskingSpec
>
;
// MaskingSpecialization
#endif
// Ref Gemm0: DataType in, AccDataType out
...
...
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward_v2.inc
View file @
e87ddb0e
...
...
@@ -14,11 +14,12 @@ int run(int argc, char* argv[])
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
G0
=
7
;
ck
::
index_t
G1
=
13
;
// Output shape C[G0, M, G1Q, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1q_m_o = reshape(C_g_m_o, [g0, g1q, m, o])
// C_g0_m_g1q_o = permute(C_g0_g1q_m_o, [0, 2, 1, 3])
ck
::
index_t
G0
=
7
;
ck
::
index_t
G1Q
=
12
;
// h_q
ck
::
index_t
G1KV
=
12
;
// h_kv
bool
input_permute
=
false
;
bool
output_permute
=
true
;
...
...
@@ -37,32 +38,33 @@ int run(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
1
3
)
else
if
(
argc
==
1
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1
=
std
::
stoi
(
argv
[
9
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1Q
=
std
::
stoi
(
argv
[
9
]);
G1KV
=
std
::
stoi
(
argv
[
10
]);
p_drop
=
std
::
stof
(
argv
[
1
0
]);
p_drop
=
std
::
stof
(
argv
[
1
1
]);
input_permute
=
std
::
stoi
(
argv
[
1
1
]);
output_permute
=
std
::
stoi
(
argv
[
1
2
]);
input_permute
=
std
::
stoi
(
argv
[
1
2
]);
output_permute
=
std
::
stoi
(
argv
[
1
3
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 1
1
: M, N, K, O, G0, G1
\n
"
);
printf
(
"arg1
0
:
scale (alpha)
\n
"
);
printf
(
"arg1
1
to 1
2
: input / output permute
\n
"
);
printf
(
"arg4 to 1
0
: M, N, K, O, G0, G1
Q, G1KV
\n
"
);
printf
(
"arg1
1
:
p_drop
\n
"
);
printf
(
"arg1
2
to 1
3
: input / output permute
\n
"
);
exit
(
0
);
}
...
...
@@ -71,45 +73,45 @@ int run(int argc, char* argv[])
float
rp_dropout
=
1.0
/
p_dropout
;
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
Q
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// A layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// A layout [G0, G1, M, K]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
K
,
K
,
G1
Q
*
K
,
1
}
// A layout [G0, M, G1
Q
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
K
,
M
*
K
,
K
,
1
};
// A layout [G0, G1
Q
, M, K]
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_lengths
{
G0
,
G1
KV
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// B0 layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// B0 layout [G0, G1, N, K]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
KV
*
K
,
K
,
G1
KV
*
K
,
1
}
// B0 layout [G0, N, G1
KV
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
KV
*
N
*
K
,
N
*
K
,
K
,
1
};
// B0 layout [G0, G1
KV
, N, K]
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_lengths
{
G0
,
G1
KV
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// B1 layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// B1 layout [G0, G1, N, O]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
KV
*
O
,
O
,
1
,
G1
KV
*
O
}
// B1 layout [G0, N, G1
KV
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
KV
*
N
*
O
,
N
*
O
,
1
,
O
};
// B1 layout [G0, G1
KV
, N, O]
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_lengths
{
G0
,
G1
Q
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_strides
=
output_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// C layout [G0, M, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// C layout [G0, G1, M, O]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
O
,
O
,
G1
Q
*
O
,
1
}
// C layout [G0, M, G1
Q
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
O
,
M
*
O
,
O
,
1
};
// C layout [G0, G1
Q
, M, O]
std
::
vector
<
ck
::
index_t
>
d_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d_gs_ms_ns_lengths
{
G0
,
G1
Q
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// D layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// D layout [G0, G1, M, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
N
,
N
,
G1
Q
*
N
,
1
}
// D layout [G0, M, G1
Q
, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
N
,
M
*
N
,
N
,
1
};
// D layout [G0, G1
Q
, M, N]
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
Q
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
N
,
N
,
G1
Q
*
N
,
1
}
// Z layout [G0, M, G1
Q
, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1
Q
, M, N]
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
Q
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
=
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
,
M
,
1
};
// LSE layout [G0, G1, M]
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
,
M
,
1
};
// LSE layout [G0, G1
Q
, M]
Tensor
<
ADataType
>
a_gs_ms_ks
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
);
Tensor
<
B0DataType
>
b0_gs_ns_ks
(
b0_gs_ns_ks_lengths
,
b0_gs_ns_ks_strides
);
...
...
@@ -224,7 +226,7 @@ int run(int argc, char* argv[])
return
0
;
}
ck
::
index_t
BatchCount
=
G0
*
G1
;
ck
::
index_t
BatchCount
=
G0
*
G1
Q
;
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
...
...
@@ -312,29 +314,37 @@ int run(int argc, char* argv[])
Tensor
<
B1DataType
>
b1_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
AccDataType
>
acc0_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after gemm0
Tensor
<
ADataType
>
a1_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after softmax
Tensor
<
ADataType
>
a1_g_m_n_drop
({
G0
*
G1
,
M
,
N
});
Tensor
<
ADataType
>
a1_g_m_n_drop
({
BatchCount
,
M
,
N
});
Tensor
<
LSEDataType
>
lse_g_m_host_result
(
{
BatchCount
,
M
});
// scratch object after max + ln(sum)
Tensor
<
Acc0BiasDataType
>
d_g_m_n
({
G0
*
G1
,
M
,
N
});
Tensor
<
ZDataType
>
z_g_m_n
({
G0
*
G1
,
M
,
N
});
Tensor
<
Acc0BiasDataType
>
d_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
ZDataType
>
z_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
CDataType
>
c_g_m_o_host_result
({
BatchCount
,
M
,
O
});
// scratch object after gemm1
// permute
a_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
a_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
a_g_m_k
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
b0_gs_ns_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b0_g_k_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
b0_g_k_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
(
G1Q
/
G1KV
);
self
(
idx
)
=
b0_gs_ns_ks
(
g0
,
g1kv
,
idx
[
2
],
idx
[
1
]);
});
b1_gs_os_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b1_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
b1_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
(
G1Q
/
G1KV
);
self
(
idx
)
=
b1_gs_os_ns
(
g0
,
g1kv
,
idx
[
2
],
idx
[
1
]);
});
d_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
d_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
d_g_m_n
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
z_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
z_g_m_n
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
// gemm 0
...
...
@@ -384,18 +394,18 @@ int run(int argc, char* argv[])
// permute
c_gs_ms_os_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
c_g_m_o_host_result
(
g
,
idx
[
2
],
idx
[
3
]);
});
lse_gs_ms_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
lse_g_m_host_result
(
g
,
idx
[
2
]);
});
...
...
example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward_v2.inc
View file @
e87ddb0e
...
...
@@ -11,6 +11,7 @@ int run(int argc, char* argv[])
bool
output_permute
=
true
;
float
p_drop
=
0.2
;
int
h_ratio
=
1
;
// G1Q / G1KV
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
offset
=
0
;
...
...
@@ -24,22 +25,25 @@ int run(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
7
)
else
if
(
argc
==
8
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
p_drop
=
std
::
stoi
(
argv
[
4
]);
input_permute
=
std
::
stoi
(
argv
[
5
]);
output_permute
=
std
::
stoi
(
argv
[
6
]);
h_ratio
=
std
::
stof
(
argv
[
5
]);
input_permute
=
std
::
stoi
(
argv
[
6
]);
output_permute
=
std
::
stoi
(
argv
[
7
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 5: input / output permute
\n
"
);
printf
(
"arg4: p_drop
\n
"
);
printf
(
"arg5: h_ratio
\n
"
);
printf
(
"arg6 to 7: input / output permute
\n
"
);
exit
(
0
);
}
...
...
@@ -61,7 +65,7 @@ int run(int argc, char* argv[])
std
::
vector
<
void
*>
p_z
;
// for result verification
std
::
vector
<
void
*>
p_z_nullptr
;
// for time test
std
::
vector
<
void
*>
p_lse
;
std
::
vector
<
std
::
vector
<
int
>>
g0_g1_m_n_k_o
;
std
::
vector
<
std
::
vector
<
int
>>
g0_g1
q
_m_n_k_o
;
std
::
vector
<
Tensor
<
ADataType
>>
a_tensors
;
std
::
vector
<
Tensor
<
B0DataType
>>
b0_tensors
;
...
...
@@ -86,54 +90,57 @@ int run(int argc, char* argv[])
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
M
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
N
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
K
=
DIM
;
int
O
=
DIM
;
int
G0
=
rand
()
%
3
+
1
;
int
G1
=
rand
()
%
5
+
1
;
int
M
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
N
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
K
=
DIM
;
int
O
=
DIM
;
int
G0
=
rand
()
%
3
+
1
;
int
G1KV
=
rand
()
%
5
+
1
;
int
G1Q
=
G1KV
*
h_ratio
;
g0_g1_m_n_k_o
.
push_back
({
G0
,
G1
,
M
,
N
,
K
,
O
});
g0_g1
q
_m_n_k_o
.
push_back
({
G0
,
G1
Q
,
M
,
N
,
K
,
O
});
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
Q
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// A layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// A layout [G0, G1, M, K]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
K
,
K
,
G1
Q
*
K
,
1
}
// A layout [G0, M, G1
Q
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
K
,
M
*
K
,
K
,
1
};
// A layout [G0, G1
Q
, M, K]
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_lengths
{
G0
,
G1
KV
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// B0 layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// B0 layout [G0, G1, N, K]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1KV
*
K
,
K
,
G1KV
*
K
,
1
}
// B0 layout [G0, N, G1KV, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1KV
*
N
*
K
,
N
*
K
,
K
,
1
};
// B0 layout [G0, G1KV, N, K]
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_lengths
{
G0
,
G1
KV
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// B1 layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// B1 layout [G0, G1, N, O]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1KV
*
O
,
O
,
1
,
G1KV
*
O
}
// B1 layout [G0, N, G1KV, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1KV
*
N
*
O
,
N
*
O
,
1
,
O
};
// B1 layout [G0, G1KV, N, O]
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_lengths
{
G0
,
G1
Q
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_strides
=
output_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// C layout [G0, M, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// C layout [G0, G1, M, O]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
O
,
O
,
G1
Q
*
O
,
1
}
// C layout [G0, M, G1
Q
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
O
,
M
*
O
,
O
,
1
};
// C layout [G0, G1
Q
, M, O]
std
::
vector
<
ck
::
index_t
>
d_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d_gs_ms_ns_lengths
{
G0
,
G1
Q
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// D layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// D layout [G0, G1, M, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
N
,
N
,
G1
Q
*
N
,
1
}
// D layout [G0, M, G1
Q
, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
N
,
M
*
N
,
N
,
1
};
// D layout [G0, G1
Q
, M, N]
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
Q
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
N
,
N
,
G1
Q
*
N
,
1
}
// Z layout [G0, M, G1
Q
, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1
Q
, M, N]
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
Q
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
=
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
,
M
,
1
};
// LSE layout [G0, G1, M]
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
,
M
,
1
};
// LSE layout [G0, G1
Q
, M]
problem_descs
.
push_back
({
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
...
...
@@ -161,7 +168,7 @@ int run(int argc, char* argv[])
Tensor
<
ZDataType
>
z_gs_ms_ns
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms_device_result
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
int
Batch
=
G0
*
G1
;
int
Batch
=
G0
*
G1
Q
;
flop
+=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
)
*
Batch
;
num_byte
+=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
sizeof
(
B1DataType
)
*
N
*
O
+
...
...
@@ -351,12 +358,12 @@ int run(int argc, char* argv[])
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
const
int
&
G0
=
g0_g1_m_n_k_o
[
i
][
0
];
const
int
&
G1
=
g0_g1_m_n_k_o
[
i
][
1
];
const
int
&
M
=
g0_g1_m_n_k_o
[
i
][
2
];
const
int
&
N
=
g0_g1_m_n_k_o
[
i
][
3
];
const
int
&
K
=
g0_g1_m_n_k_o
[
i
][
4
];
const
int
&
O
=
g0_g1_m_n_k_o
[
i
][
5
];
const
int
&
G0
=
g0_g1
q
_m_n_k_o
[
i
][
0
];
const
int
&
G1
Q
=
g0_g1
q
_m_n_k_o
[
i
][
1
];
const
int
&
M
=
g0_g1
q
_m_n_k_o
[
i
][
2
];
const
int
&
N
=
g0_g1
q
_m_n_k_o
[
i
][
3
];
const
int
&
K
=
g0_g1
q
_m_n_k_o
[
i
][
4
];
const
int
&
O
=
g0_g1
q
_m_n_k_o
[
i
][
5
];
const
auto
&
c_gs_ms_os_lengths
=
problem_descs
[
i
]
.
c_gs_ms_os_lengths
;
const
auto
&
c_gs_ms_os_strides
=
problem_descs
[
i
]
.
c_gs_ms_os_strides
;
...
...
@@ -378,36 +385,43 @@ int run(int argc, char* argv[])
z_gs_ms_ns_device_buf
.
FromDevice
(
z_gs_ms_ns_device_result
.
mData
.
data
());
lse_gs_ms_device_buf
.
FromDevice
(
lse_gs_ms_device_result
.
mData
.
data
());
Tensor
<
ADataType
>
a_g_m_k
({
G0
*
G1
,
M
,
K
});
Tensor
<
B0DataType
>
b0_g_k_n
({
G0
*
G1
,
K
,
N
});
Tensor
<
B1DataType
>
b1_g_n_o
({
G0
*
G1
,
N
,
O
});
Tensor
<
AccDataType
>
acc0_g_m_n
({
G0
*
G1
,
M
,
N
});
// scratch object after gemm0
Tensor
<
Acc0BiasDataType
>
d_g_m_n
({
G0
*
G1
,
M
,
N
});
Tensor
<
ADataType
>
a1_g_m_n
({
G0
*
G1
,
M
,
N
});
// scratch object after softmax
Tensor
<
ADataType
>
a1_g_m_n_drop
({
G0
*
G1
,
M
,
N
});
// scratch object after softmax
Tensor
<
CDataType
>
c_g_m_o_host_result
({
G0
*
G1
,
M
,
O
});
// scratch object after gemm1
Tensor
<
ADataType
>
a_g_m_k
({
G0
*
G1
Q
,
M
,
K
});
Tensor
<
B0DataType
>
b0_g_k_n
({
G0
*
G1
Q
,
K
,
N
});
Tensor
<
B1DataType
>
b1_g_n_o
({
G0
*
G1
Q
,
N
,
O
});
Tensor
<
AccDataType
>
acc0_g_m_n
({
G0
*
G1
Q
,
M
,
N
});
// scratch object after gemm0
Tensor
<
Acc0BiasDataType
>
d_g_m_n
({
G0
*
G1
Q
,
M
,
N
});
Tensor
<
ADataType
>
a1_g_m_n
({
G0
*
G1
Q
,
M
,
N
});
// scratch object after softmax
Tensor
<
ADataType
>
a1_g_m_n_drop
({
G0
*
G1
Q
,
M
,
N
});
// scratch object after softmax
Tensor
<
CDataType
>
c_g_m_o_host_result
({
G0
*
G1
Q
,
M
,
O
});
// scratch object after gemm1
Tensor
<
CDataType
>
c_gs_ms_os_host_result
(
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
);
Tensor
<
ZDataType
>
z_g_m_n
({
G0
*
G1
,
M
,
N
});
Tensor
<
LSEDataType
>
lse_g_m_host_result
({
G0
*
G1
,
M
});
// scratch object after gemm1
Tensor
<
ZDataType
>
z_g_m_n
({
G0
*
G1
Q
,
M
,
N
});
Tensor
<
LSEDataType
>
lse_g_m_host_result
({
G0
*
G1
Q
,
M
});
// scratch object after gemm1
Tensor
<
LSEDataType
>
lse_gs_ms_host_result
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
// permute
a_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
a_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
a_g_m_k
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
b0_gs_ns_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b0_g_k_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
b1_gs_os_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b1_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
b0_g_k_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
h_ratio
;
self
(
idx
)
=
b0_gs_ns_ks
(
g0
,
g1kv
,
idx
[
2
],
idx
[
1
]);
});
b1_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
h_ratio
;
self
(
idx
)
=
b1_gs_os_ns
(
g0
,
g1kv
,
idx
[
2
],
idx
[
1
]);
});
d_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
d_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
d_g_m_n
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
z_gs_ms_ns_device_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
z_g_m_n
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
// gemm 0
...
...
@@ -461,18 +475,18 @@ int run(int argc, char* argv[])
// permute
c_gs_ms_os_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
c_g_m_o_host_result
(
g
,
idx
[
2
],
idx
[
3
]);
});
lse_gs_ms_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
lse_g_m_host_result
(
g
,
idx
[
2
]);
});
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
e87ddb0e
...
...
@@ -132,14 +132,17 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
bgrad_grid_desc_bk0_n_bk1
,
const
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
B1GridDesc_BK0_N_BK1
b1grad_grid_desc_bk0_n_bk1
,
const
LSEGridDescriptor_M
lse_grid_desc_m
,
const
YGradGridDesc_O0_M_O1
ygrad_grid_desc_o0_m_o1
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
const
index_t
h_ratio
,
const
index_t
nblock
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
C0MatrixMask
c0_matrix_mask
,
...
...
@@ -155,21 +158,26 @@ __global__ void
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
index_t
gkv_idx
=
__builtin_amdgcn_readfirstlane
(
g_idx
/
h_ratio
);
// NOTE: assumes QKVY has the same layout as dQ/dK/dV/dY therefore being able to reuse batch
// offsets
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetABasePtr
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetBBasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetBBasePtr
(
g
kv
_idx
)));
const
long_index_t
z_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetZBasePtr
(
g_idx
)));
const
long_index_t
b1_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB1BasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB1BasePtr
(
g
kv
_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetCBasePtr
(
g_idx
)));
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetLSEBasePtr
(
g_idx
)));
const
long_index_t
bgrad_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetBGradBasePtr
(
g_idx
)));
const
long_index_t
b1grad_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB1GradBasePtr
(
g_idx
)));
ck
::
philox
ph
(
seed
,
0
,
offset
);
ZDataType
*
z_matrix_ptr
=
(
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
);
...
...
@@ -205,9 +213,9 @@ __global__ void
p_d_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_kgrad_grid
+
b
grad
_batch_offset
,
tmp_p_d0grad_grid
,
p_vgrad_grid
+
b1_batch_offset
,
p_vgrad_grid
+
b1
grad
_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
...
...
@@ -216,9 +224,11 @@ __global__ void
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
bgrad_grid_desc_bk0_n_bk1
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
b1grad_grid_desc_bk0_n_bk1
,
lse_grid_desc_m
,
ygrad_grid_desc_o0_m_o1
,
block_2_ctile_map
,
...
...
@@ -242,9 +252,9 @@ __global__ void
p_d_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_kgrad_grid
+
b
grad
_batch_offset
,
tmp_p_d0grad_grid
,
p_vgrad_grid
+
b1_batch_offset
,
p_vgrad_grid
+
b1
grad
_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
...
...
@@ -253,9 +263,11 @@ __global__ void
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
bgrad_grid_desc_bk0_n_bk1
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
b1grad_grid_desc_bk0_n_bk1
,
lse_grid_desc_m
,
ygrad_grid_desc_o0_m_o1
,
block_2_ctile_map
,
...
...
@@ -286,13 +298,16 @@ __global__ void
ignore
=
c_element_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
bgrad_grid_desc_bk0_n_bk1
;
ignore
=
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
ignore
=
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
;
ignore
=
b1_grid_desc_bk0_n_bk1
;
ignore
=
b1grad_grid_desc_bk0_n_bk1
;
ignore
=
lse_grid_desc_m
;
ignore
=
ygrad_grid_desc_o0_m_o1
;
ignore
=
block_2_ctile_map
;
ignore
=
batch_count
;
ignore
=
h_ratio
;
ignore
=
nblock
;
ignore
=
compute_base_ptr_of_batch
;
ignore
=
c0_matrix_mask
;
...
...
@@ -695,6 +710,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const
ZGridDesc_G_M_N
&
z_grid_desc_g_m_n
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
,
const
BGridDesc_G_N_K
&
bgrad_grid_desc_g_n_k
,
const
B1GridDesc_G_N_K
&
b1grad_grid_desc_g_n_k
,
index_t
BatchStrideLSE
)
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
...
...
@@ -702,6 +719,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
z_grid_desc_g_m_n_
(
z_grid_desc_g_m_n
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
),
bgrad_grid_desc_g_n_k_
(
bgrad_grid_desc_g_n_k
),
b1grad_grid_desc_g_n_k_
(
b1grad_grid_desc_g_n_k
),
BatchStrideLSE_
(
BatchStrideLSE
)
{
}
...
...
@@ -741,6 +760,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideLSE_
);
}
__host__
__device__
constexpr
long_index_t
GetBGradBasePtr
(
index_t
g_idx
)
const
{
return
bgrad_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetB1GradBasePtr
(
index_t
g_idx
)
const
{
return
b1grad_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
private:
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
...
...
@@ -748,6 +777,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
BGridDesc_G_N_K
bgrad_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1grad_grid_desc_g_n_k_
;
index_t
BatchStrideLSE_
;
};
...
...
@@ -858,6 +889,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_lengths
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_strides
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
...
...
@@ -888,9 +923,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
)},
bgrad_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
bgrad_gs_ns_ks_lengths
,
bgrad_gs_ns_ks_strides
)},
z_grid_desc_m_n_
{
MakeZGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
b1_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeVGridDescriptor_O0_N_O1
(
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
b1grad_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeVGridDescriptor_O0_N_O1
(
b1grad_gs_gemm1ns_gemm1ks_lengths
,
b1grad_gs_gemm1ns_gemm1ks_strides
)},
y_grid_desc_m_o_
{
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
d_y_grid_desc_m_o_
{
DTransform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
...
...
@@ -912,6 +951,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
c_gs_ms_gemm1ns_strides
)},
z_grid_desc_g_m_n_
{
Transform
::
MakeC0GridDescriptor_G_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
bgrad_grid_desc_g_n_k_
{
Transform
::
MakeB0GridDescriptor_G_N_K
(
bgrad_gs_ns_ks_lengths
,
bgrad_gs_ns_ks_strides
)},
b1grad_grid_desc_g_n_k_
{
Transform
::
MakeB1GridDescriptor_G_N_K
(
b1grad_gs_gemm1ns_gemm1ks_lengths
,
b1grad_gs_gemm1ns_gemm1ks_strides
)},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
k_grid_desc_n_k_
)},
d_block_2_ctile_map_
{
GridwiseYDotYGrad
::
MakeDefaultBlock2CTileMap
(
d_y_grid_desc_m_o_
)},
...
...
@@ -935,6 +978,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
c_mz_gemm1nz_strides_
{
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
-
1
],
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
batch_count_
{
c_grid_desc_g_m_n_
.
GetLength
(
I0
)},
h_ratio_
{
c_grid_desc_g_m_n_
.
GetLength
(
I0
)
/
b_grid_desc_g_n_k_
.
GetLength
(
I0
)},
p_drop_
{
p_drop
}
{
// TODO: implement bias addition
...
...
@@ -964,6 +1008,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
z_grid_desc_g_m_n_
,
b1_grid_desc_g_n_k_
,
c_grid_desc_g_m_n_
,
bgrad_grid_desc_g_n_k_
,
b1grad_grid_desc_g_n_k_
,
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
()));
seed_
=
std
::
get
<
0
>
(
seeds
);
...
...
@@ -990,7 +1036,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
<<
b_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// b_grid_desc_g_n_k_.Print();
std
::
cout
<<
"b1_grid_desc_g_
o_n
_: "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"b1_grid_desc_g_
n_k
_: "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// b1_grid_desc_g_n_k_.Print();
...
...
@@ -1003,6 +1049,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
std
::
cout
<<
"ygrad_grid_desc_o0_m_o1_: "
<<
ygrad_grid_desc_o0_m_o1_
.
GetLength
(
I0
)
<<
", "
<<
ygrad_grid_desc_o0_m_o1_
.
GetLength
(
I1
)
<<
", "
<<
ygrad_grid_desc_o0_m_o1_
.
GetLength
(
I2
)
<<
'\n'
;
std
::
cout
<<
"d0_grid_desc_g_m_n_: "
<<
d0_grid_desc_g_m_n_
.
GetLength
(
I0
)
<<
", "
<<
d0_grid_desc_g_m_n_
.
GetLength
(
I1
)
<<
", "
<<
d0_grid_desc_g_m_n_
.
GetLength
(
I2
)
<<
'\n'
;
std
::
cout
<<
"bgrad_grid_desc_g_n_k_: "
<<
bgrad_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
bgrad_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
bgrad_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// bgrad_grid_desc_g_n_k_.Print();
std
::
cout
<<
"b1grad_grid_desc_g_n_k_: "
<<
b1grad_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b1grad_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b1grad_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// b1grad_grid_desc_g_n_k_.Print();
}
// pointers
...
...
@@ -1023,9 +1080,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// tensor descriptor
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
bgrad_grid_desc_bk0_n_bk1_
;
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
B1GridDesc_BK0_N_BK1
b1grad_grid_desc_bk0_n_bk1_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
DYGridDesc_M_O
d_y_grid_desc_m_o_
;
LSEGridDesc_M
lse_grid_desc_m_
;
...
...
@@ -1040,6 +1099,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
BGridDesc_G_N_K
bgrad_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1grad_grid_desc_g_n_k_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
;
...
...
@@ -1068,6 +1129,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
std
::
vector
<
index_t
>
c_mz_gemm1nz_strides_
;
index_t
batch_count_
;
index_t
h_ratio_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
float
p_drop_
;
...
...
@@ -1189,13 +1251,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
bgrad_grid_desc_bk0_n_bk1_
,
arg
.
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
b1grad_grid_desc_bk0_n_bk1_
,
arg
.
lse_grid_desc_m_
,
arg
.
ygrad_grid_desc_o0_m_o1_
,
arg
.
block_2_ctile_map_
,
arg
.
batch_count_
,
arg
.
h_ratio_
,
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
k_grid_desc_n_k_
),
arg
.
compute_base_ptr_of_batch_
,
arg
.
c0_matrix_mask_
,
...
...
@@ -1249,13 +1314,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// Check if C permute dimension matches GEMM + GEMM shape
const
index_t
c_g
=
arg
.
c_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
b_g
=
arg
.
b_grid_desc_g_n_k_
.
GetLength
(
I0
);
const
index_t
c_m
=
arg
.
y_grid_desc_m_o_
.
GetLength
(
I0
);
const
index_t
c_gemm1n
=
arg
.
y_grid_desc_m_o_
.
GetLength
(
I1
);
const
index_t
a_m
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
b1_gemm1n
=
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
))
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
&&
c_g
%
b_g
==
0
))
{
return
false
;
}
...
...
@@ -1293,6 +1359,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
return
false
;
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if
constexpr
(
is_same
<
OutputDataType
,
half_t
>::
value
||
is_same
<
OutputDataType
,
bhalf_t
>::
value
)
{
if
(
KzRaw
%
2
!=
0
)
{
std
::
cout
<<
"K_q must be a multiple of 2"
<<
std
::
endl
;
return
false
;
}
}
// Check vector load/store requirement
const
auto
a_stride_lowest
=
ABlockTransferSrcVectorDim
==
2
?
arg
.
a_mz_kz_strides_
[
1
]
:
arg
.
a_mz_kz_strides_
[
0
];
...
...
@@ -1345,6 +1421,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_lengths
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_strides
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
...
...
@@ -1385,6 +1465,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
lse_gs_ms_lengths
,
bgrad_gs_ns_ks_lengths
,
bgrad_gs_ns_ks_strides
,
b1grad_gs_gemm1ns_gemm1ks_lengths
,
b1grad_gs_gemm1ns_gemm1ks_strides
,
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
,
acc1_bias_gs_ms_gemm1ns_lengths
,
// acc1_bias_gs_ms_os_lengths
...
...
@@ -1414,8 +1498,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
void
*
p_qgrad_grid
,
void
*
p_kgrad_grid
,
void
*
p_vgrad_grid
,
const
D0DataType
*
p_acc0_bias
,
const
D1DataType
*
p_acc1_bias
,
const
void
*
p_acc0_bias
,
const
void
*
p_acc1_bias
,
void
*
p_d0grad_grid
,
void
*
p_d1grad_grid
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
...
...
@@ -1429,6 +1513,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_lengths
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_strides
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
...
...
@@ -1470,6 +1558,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
lse_gs_ms_lengths
,
bgrad_gs_ns_ks_lengths
,
bgrad_gs_ns_ks_strides
,
b1grad_gs_gemm1ns_gemm1ks_lengths
,
b1grad_gs_gemm1ns_gemm1ks_strides
,
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
,
acc1_bias_gs_ms_gemm1ns_lengths
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
e87ddb0e
...
...
@@ -132,14 +132,17 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
bgrad_grid_desc_bk0_n_bk1
,
const
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
B1GridDesc_BK0_N_BK1
b1grad_grid_desc_bk0_n_bk1
,
const
LSEGridDescriptor_M
lse_grid_desc_m
,
const
YGradGridDesc_M0_O_M1
ygrad_grid_desc_m0_o_m1
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
const
index_t
h_ratio
,
const
index_t
nblock
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
C0MatrixMask
c0_matrix_mask
,
...
...
@@ -155,21 +158,26 @@ __global__ void
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
index_t
gkv_idx
=
__builtin_amdgcn_readfirstlane
(
g_idx
/
h_ratio
);
// NOTE: assumes QKVY has the same layout as dQ/dK/dV/dY therefore being able to reuse batch
// offsets
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetABasePtr
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetBBasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetBBasePtr
(
g
kv
_idx
)));
const
long_index_t
z_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetZBasePtr
(
g_idx
)));
const
long_index_t
b1_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB1BasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB1BasePtr
(
g
kv
_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetCBasePtr
(
g_idx
)));
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetLSEBasePtr
(
g_idx
)));
const
long_index_t
bgrad_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetBGradBasePtr
(
g_idx
)));
const
long_index_t
b1grad_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB1GradBasePtr
(
g_idx
)));
ck
::
philox
ph
(
seed
,
0
,
offset
);
ZDataType
*
z_matrix_ptr
=
(
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
);
...
...
@@ -206,9 +214,9 @@ __global__ void
p_d_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_kgrad_grid
+
b
grad
_batch_offset
,
tmp_p_d0grad_grid
,
p_vgrad_grid
+
b1_batch_offset
,
p_vgrad_grid
+
b1
grad
_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
...
...
@@ -217,9 +225,11 @@ __global__ void
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
bgrad_grid_desc_bk0_n_bk1
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
b1grad_grid_desc_bk0_n_bk1
,
lse_grid_desc_m
,
ygrad_grid_desc_m0_o_m1
,
block_2_ctile_map
,
...
...
@@ -243,9 +253,9 @@ __global__ void
p_d_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_kgrad_grid
+
b
grad
_batch_offset
,
tmp_p_d0grad_grid
,
p_vgrad_grid
+
b1_batch_offset
,
p_vgrad_grid
+
b1
grad
_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
...
...
@@ -254,9 +264,11 @@ __global__ void
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
bgrad_grid_desc_bk0_n_bk1
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
b1grad_grid_desc_bk0_n_bk1
,
lse_grid_desc_m
,
ygrad_grid_desc_m0_o_m1
,
block_2_ctile_map
,
...
...
@@ -287,13 +299,16 @@ __global__ void
ignore
=
c_element_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
bgrad_grid_desc_bk0_n_bk1
;
ignore
=
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
ignore
=
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
;
ignore
=
b1_grid_desc_bk0_n_bk1
;
ignore
=
b1grad_grid_desc_bk0_n_bk1
;
ignore
=
lse_grid_desc_m
;
ignore
=
ygrad_grid_desc_m0_o_m1
;
ignore
=
block_2_ctile_map
;
ignore
=
batch_count
;
ignore
=
h_ratio
;
ignore
=
nblock
;
ignore
=
compute_base_ptr_of_batch
;
ignore
=
c0_matrix_mask
;
...
...
@@ -704,6 +719,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
ZGridDesc_G_M_N
&
z_grid_desc_g_m_n
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
,
const
BGridDesc_G_N_K
&
bgrad_grid_desc_g_n_k
,
const
B1GridDesc_G_N_K
&
b1grad_grid_desc_g_n_k
,
index_t
BatchStrideLSE
)
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
...
...
@@ -711,6 +728,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
z_grid_desc_g_m_n_
(
z_grid_desc_g_m_n
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
),
bgrad_grid_desc_g_n_k_
(
bgrad_grid_desc_g_n_k
),
b1grad_grid_desc_g_n_k_
(
b1grad_grid_desc_g_n_k
),
BatchStrideLSE_
(
BatchStrideLSE
)
{
}
...
...
@@ -729,6 +748,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{
return
d0_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetZBasePtr
(
index_t
g_idx
)
const
{
return
z_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
...
...
@@ -749,6 +769,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideLSE_
);
}
__host__
__device__
constexpr
long_index_t
GetBGradBasePtr
(
index_t
g_idx
)
const
{
return
bgrad_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetB1GradBasePtr
(
index_t
g_idx
)
const
{
return
b1grad_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
private:
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
...
...
@@ -756,6 +786,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
BGridDesc_G_N_K
bgrad_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1grad_grid_desc_g_n_k_
;
index_t
BatchStrideLSE_
;
};
...
...
@@ -874,6 +906,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_lengths
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_strides
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
...
...
@@ -904,9 +940,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
)},
bgrad_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
bgrad_gs_ns_ks_lengths
,
bgrad_gs_ns_ks_strides
)},
z_grid_desc_m_n_
{
MakeZGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
b1_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeVGridDescriptor_O0_N_O1
(
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
b1grad_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeVGridDescriptor_O0_N_O1
(
b1grad_gs_gemm1ns_gemm1ks_lengths
,
b1grad_gs_gemm1ns_gemm1ks_strides
)},
y_grid_desc_m_o_
{
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
d_y_grid_desc_m_o_
{
DTransform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
...
...
@@ -927,6 +967,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
c_gs_ms_gemm1ns_strides
)},
z_grid_desc_g_m_n_
{
Transform
::
MakeC0GridDescriptor_G_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
bgrad_grid_desc_g_n_k_
{
Transform
::
MakeB0GridDescriptor_G_N_K
(
bgrad_gs_ns_ks_lengths
,
bgrad_gs_ns_ks_strides
)},
b1grad_grid_desc_g_n_k_
{
Transform
::
MakeB1GridDescriptor_G_N_K
(
b1grad_gs_gemm1ns_gemm1ks_lengths
,
b1grad_gs_gemm1ns_gemm1ks_strides
)},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
k_grid_desc_n_k_
)},
d_block_2_ctile_map_
{
GridwiseYDotYGrad
::
MakeDefaultBlock2CTileMap
(
d_y_grid_desc_m_o_
)},
...
...
@@ -950,6 +994,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
c_mz_gemm1nz_strides_
{
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
-
1
],
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
batch_count_
{
c_grid_desc_g_m_n_
.
GetLength
(
I0
)},
h_ratio_
{
c_grid_desc_g_m_n_
.
GetLength
(
I0
)
/
b_grid_desc_g_n_k_
.
GetLength
(
I0
)},
p_drop_
{
p_drop
}
{
// TODO: implement bias addition
...
...
@@ -979,6 +1024,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
z_grid_desc_g_m_n_
,
b1_grid_desc_g_n_k_
,
c_grid_desc_g_m_n_
,
bgrad_grid_desc_g_n_k_
,
b1grad_grid_desc_g_n_k_
,
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
()));
seed_
=
std
::
get
<
0
>
(
seeds
);
...
...
@@ -1005,7 +1052,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
<<
b_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// b_grid_desc_g_n_k_.Print();
std
::
cout
<<
"b1_grid_desc_g_
o_n
_: "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"b1_grid_desc_g_
n_k
_: "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// b1_grid_desc_g_n_k_.Print();
...
...
@@ -1018,6 +1065,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
std
::
cout
<<
"ygrad_grid_desc_m0_o_m1_: "
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I0
)
<<
", "
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I1
)
<<
", "
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I2
)
<<
'\n'
;
std
::
cout
<<
"d0_grid_desc_g_m_n_: "
<<
d0_grid_desc_g_m_n_
.
GetLength
(
I0
)
<<
", "
<<
d0_grid_desc_g_m_n_
.
GetLength
(
I1
)
<<
", "
<<
d0_grid_desc_g_m_n_
.
GetLength
(
I2
)
<<
'\n'
;
std
::
cout
<<
"bgrad_grid_desc_g_n_k_: "
<<
bgrad_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
bgrad_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
bgrad_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// bgrad_grid_desc_g_n_k_.Print();
std
::
cout
<<
"b1grad_grid_desc_g_n_k_: "
<<
b1grad_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b1grad_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b1grad_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// b1grad_grid_desc_g_n_k_.Print();
}
// pointers
...
...
@@ -1038,9 +1096,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// tensor descriptor
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
bgrad_grid_desc_bk0_n_bk1_
;
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
B1GridDesc_BK0_N_BK1
b1grad_grid_desc_bk0_n_bk1_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
DYGridDesc_M_O
d_y_grid_desc_m_o_
;
LSEGridDesc_M
lse_grid_desc_m_
;
...
...
@@ -1055,6 +1115,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
BGridDesc_G_N_K
bgrad_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1grad_grid_desc_g_n_k_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
;
...
...
@@ -1083,6 +1145,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
std
::
vector
<
index_t
>
c_mz_gemm1nz_strides_
;
index_t
batch_count_
;
index_t
h_ratio_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
float
p_drop_
;
...
...
@@ -1208,13 +1271,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
bgrad_grid_desc_bk0_n_bk1_
,
arg
.
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
b1grad_grid_desc_bk0_n_bk1_
,
arg
.
lse_grid_desc_m_
,
arg
.
ygrad_grid_desc_m0_o_m1_
,
arg
.
block_2_ctile_map_
,
arg
.
batch_count_
,
arg
.
h_ratio_
,
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
k_grid_desc_n_k_
),
arg
.
compute_base_ptr_of_batch_
,
arg
.
c0_matrix_mask_
,
...
...
@@ -1280,13 +1346,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// Check if C permute dimension matches GEMM + GEMM shape
const
index_t
c_g
=
arg
.
c_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
b_g
=
arg
.
b_grid_desc_g_n_k_
.
GetLength
(
I0
);
const
index_t
c_m
=
arg
.
y_grid_desc_m_o_
.
GetLength
(
I0
);
const
index_t
c_gemm1n
=
arg
.
y_grid_desc_m_o_
.
GetLength
(
I1
);
const
index_t
a_m
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
b1_gemm1n
=
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
))
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
&&
c_g
%
b_g
==
0
))
{
return
false
;
}
...
...
@@ -1325,6 +1392,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
return
false
;
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if
constexpr
(
is_same
<
OutputDataType
,
half_t
>::
value
||
is_same
<
OutputDataType
,
bhalf_t
>::
value
)
{
if
(
KzRaw
%
2
!=
0
)
{
std
::
cout
<<
"K_q must be a multiple of 2"
<<
std
::
endl
;
return
false
;
}
}
// Check vector load/store requirement
const
auto
a_stride_lowest
=
ABlockTransferSrcVectorDim
==
2
?
arg
.
a_mz_kz_strides_
[
1
]
:
arg
.
a_mz_kz_strides_
[
0
];
...
...
@@ -1380,6 +1457,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_lengths
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_strides
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
...
...
@@ -1420,6 +1501,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
lse_gs_ms_lengths
,
bgrad_gs_ns_ks_lengths
,
bgrad_gs_ns_ks_strides
,
b1grad_gs_gemm1ns_gemm1ks_lengths
,
b1grad_gs_gemm1ns_gemm1ks_strides
,
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
,
acc1_bias_gs_ms_gemm1ns_lengths
,
// acc1_bias_gs_ms_os_lengths
...
...
@@ -1464,6 +1549,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_lengths
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_strides
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
...
...
@@ -1505,6 +1594,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
lse_gs_ms_lengths
,
bgrad_gs_ns_ks_lengths
,
bgrad_gs_ns_ks_strides
,
b1grad_gs_gemm1ns_gemm1ks_lengths
,
b1grad_gs_gemm1ns_gemm1ks_strides
,
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
,
acc1_bias_gs_ms_gemm1ns_lengths
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
e87ddb0e
...
...
@@ -74,16 +74,19 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
bgrad_grid_desc_bk0_n_bk1
,
const
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
B1GridDesc_BK0_N_BK1
b1grad_grid_desc_bk0_n_bk1
,
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
LSEGridDescriptor_M
lse_grid_desc_m
,
const
YGradGridDesc_O0_M_O1
ygrad_grid_desc_o0_m_o1
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
const
index_t
h_ratio
,
const
index_t
nblock
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
C0MatrixMask
c0_matrix_mask
,
...
...
@@ -99,21 +102,26 @@ __global__ void
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
index_t
gkv_idx
=
__builtin_amdgcn_readfirstlane
(
g_idx
/
h_ratio
);
// NOTE: assumes QKVY has the same layout as dQ/dK/dV/dY therefore being able to reuse batch
// offsets
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetABasePtr
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetBBasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetBBasePtr
(
g
kv
_idx
)));
const
long_index_t
z_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetZBasePtr
(
g_idx
)));
const
long_index_t
b1_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB1BasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB1BasePtr
(
g
kv
_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetCBasePtr
(
g_idx
)));
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetLSEBasePtr
(
g_idx
)));
const
long_index_t
bgrad_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetBGradBasePtr
(
g_idx
)));
const
long_index_t
b1grad_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB1GradBasePtr
(
g_idx
)));
ck
::
philox
ph
(
seed
,
0
,
offset
);
ZDataType
*
z_matrix_ptr
=
(
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
);
...
...
@@ -149,9 +157,9 @@ __global__ void
p_lse_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_kgrad_grid
+
b
grad
_batch_offset
,
tmp_p_d0grad_grid
,
p_vgrad_grid
+
b1_batch_offset
,
p_vgrad_grid
+
b1
grad
_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
...
...
@@ -160,9 +168,11 @@ __global__ void
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
bgrad_grid_desc_bk0_n_bk1
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
b1grad_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
ygrad_grid_desc_o0_m_o1
,
...
...
@@ -187,9 +197,9 @@ __global__ void
p_lse_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_kgrad_grid
+
b
grad
_batch_offset
,
tmp_p_d0grad_grid
,
p_vgrad_grid
+
b1_batch_offset
,
p_vgrad_grid
+
b1
grad
_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
...
...
@@ -198,9 +208,11 @@ __global__ void
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
bgrad_grid_desc_bk0_n_bk1
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
b1grad_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
ygrad_grid_desc_o0_m_o1
,
...
...
@@ -232,14 +244,17 @@ __global__ void
ignore
=
c_element_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
bgrad_grid_desc_bk0_n_bk1
;
ignore
=
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
ignore
=
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
;
ignore
=
b1_grid_desc_bk0_n_bk1
;
ignore
=
b1grad_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
lse_grid_desc_m
;
ignore
=
ygrad_grid_desc_o0_m_o1
;
ignore
=
block_2_ctile_map
;
ignore
=
batch_count
;
ignore
=
h_ratio
;
ignore
=
nblock
;
ignore
=
compute_base_ptr_of_batch
;
ignore
=
c0_matrix_mask
;
...
...
@@ -603,6 +618,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
ZGridDesc_G_M_N
&
z_grid_desc_g_m_n
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
,
const
BGridDesc_G_N_K
&
bgrad_grid_desc_g_n_k
,
const
B1GridDesc_G_N_K
&
b1grad_grid_desc_g_n_k
,
index_t
BatchStrideLSE
)
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
...
...
@@ -610,6 +627,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
z_grid_desc_g_m_n_
(
z_grid_desc_g_m_n
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
),
bgrad_grid_desc_g_n_k_
(
bgrad_grid_desc_g_n_k
),
b1grad_grid_desc_g_n_k_
(
b1grad_grid_desc_g_n_k
),
BatchStrideLSE_
(
BatchStrideLSE
)
{
}
...
...
@@ -649,6 +668,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideLSE_
);
}
__host__
__device__
constexpr
long_index_t
GetBGradBasePtr
(
index_t
g_idx
)
const
{
return
bgrad_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetB1GradBasePtr
(
index_t
g_idx
)
const
{
return
b1grad_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
private:
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
...
...
@@ -656,6 +685,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
BGridDesc_G_N_K
bgrad_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1grad_grid_desc_g_n_k_
;
index_t
BatchStrideLSE_
;
};
...
...
@@ -755,6 +786,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_lengths
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_strides
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
...
...
@@ -784,9 +819,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
)},
bgrad_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
bgrad_gs_ns_ks_lengths
,
bgrad_gs_ns_ks_strides
)},
z_grid_desc_m_n_
{
MakeZGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
b1_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeVGridDescriptor_O0_N_O1
(
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
b1grad_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeVGridDescriptor_O0_N_O1
(
b1grad_gs_gemm1ns_gemm1ks_lengths
,
b1grad_gs_gemm1ns_gemm1ks_strides
)},
y_grid_desc_m_o_
{
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
lse_grid_desc_m_
{
DeviceOp
::
MakeLSEGridDescriptor_M
(
lse_gs_ms_lengths
[
NumDimG
])},
...
...
@@ -805,6 +844,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_gs_ms_gemm1ns_strides
)},
z_grid_desc_g_m_n_
{
Transform
::
MakeC0GridDescriptor_G_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
bgrad_grid_desc_g_n_k_
{
Transform
::
MakeB0GridDescriptor_G_N_K
(
bgrad_gs_ns_ks_lengths
,
bgrad_gs_ns_ks_strides
)},
b1grad_grid_desc_g_n_k_
{
Transform
::
MakeB1GridDescriptor_G_N_K
(
b1grad_gs_gemm1ns_gemm1ks_lengths
,
b1grad_gs_gemm1ns_gemm1ks_strides
)},
y_grid_desc_mblock_mperblock_oblock_operblock_
{},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
k_grid_desc_n_k_
)},
a_element_op_
{
a_element_op
},
...
...
@@ -826,6 +869,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_mz_gemm1nz_strides_
{
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
-
1
],
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
batch_count_
{
c_grid_desc_g_m_n_
.
GetLength
(
I0
)},
h_ratio_
{
c_grid_desc_g_m_n_
.
GetLength
(
I0
)
/
b_grid_desc_g_n_k_
.
GetLength
(
I0
)},
p_drop_
{
p_drop
}
{
// TODO: implement bias addition
...
...
@@ -864,6 +908,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
z_grid_desc_g_m_n_
,
b1_grid_desc_g_n_k_
,
c_grid_desc_g_m_n_
,
bgrad_grid_desc_g_n_k_
,
b1grad_grid_desc_g_n_k_
,
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
()));
seed_
=
std
::
get
<
0
>
(
seeds
);
...
...
@@ -887,7 +933,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<<
b_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// b_grid_desc_g_n_k_.Print();
std
::
cout
<<
"b1_grid_desc_g_
o_n
_: "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"b1_grid_desc_g_
n_k
_: "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// b1_grid_desc_g_n_k_.Print();
...
...
@@ -900,6 +946,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std
::
cout
<<
"ygrad_grid_desc_o0_m_o1_: "
<<
ygrad_grid_desc_o0_m_o1_
.
GetLength
(
I0
)
<<
", "
<<
ygrad_grid_desc_o0_m_o1_
.
GetLength
(
I1
)
<<
", "
<<
ygrad_grid_desc_o0_m_o1_
.
GetLength
(
I2
)
<<
'\n'
;
std
::
cout
<<
"d0_grid_desc_g_m_n_: "
<<
d0_grid_desc_g_m_n_
.
GetLength
(
I0
)
<<
", "
<<
d0_grid_desc_g_m_n_
.
GetLength
(
I1
)
<<
", "
<<
d0_grid_desc_g_m_n_
.
GetLength
(
I2
)
<<
'\n'
;
std
::
cout
<<
"bgrad_grid_desc_g_n_k_: "
<<
bgrad_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
bgrad_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
bgrad_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// bgrad_grid_desc_g_n_k_.Print();
std
::
cout
<<
"b1grad_grid_desc_g_n_k_: "
<<
b1grad_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b1grad_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b1grad_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// b1grad_grid_desc_g_n_k_.Print();
}
// pointers
...
...
@@ -919,9 +976,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// tensor descriptor
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
bgrad_grid_desc_bk0_n_bk1_
;
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
B1GridDesc_BK0_N_BK1
b1grad_grid_desc_bk0_n_bk1_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
LSEGridDesc_M
lse_grid_desc_m_
;
KGridDesc_N_K
k_grid_desc_n_k_
;
...
...
@@ -934,6 +993,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
BGridDesc_G_N_K
bgrad_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1grad_grid_desc_g_n_k_
;
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_
;
...
...
@@ -961,6 +1022,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std
::
vector
<
index_t
>
c_mz_gemm1nz_strides_
;
index_t
batch_count_
;
index_t
h_ratio_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
float
p_drop_
;
...
...
@@ -1047,14 +1109,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
bgrad_grid_desc_bk0_n_bk1_
,
arg
.
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
b1grad_grid_desc_bk0_n_bk1_
,
arg
.
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg
.
lse_grid_desc_m_
,
arg
.
ygrad_grid_desc_o0_m_o1_
,
arg
.
block_2_ctile_map_
,
arg
.
batch_count_
,
arg
.
h_ratio_
,
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
k_grid_desc_n_k_
),
arg
.
compute_base_ptr_of_batch_
,
arg
.
c0_matrix_mask_
,
...
...
@@ -1108,13 +1173,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// Check if C permute dimension matches GEMM + GEMM shape
const
index_t
c_g
=
arg
.
c_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
b_g
=
arg
.
b_grid_desc_g_n_k_
.
GetLength
(
I0
);
const
index_t
c_m
=
arg
.
y_grid_desc_m_o_
.
GetLength
(
I0
);
const
index_t
c_gemm1n
=
arg
.
y_grid_desc_m_o_
.
GetLength
(
I1
);
const
index_t
a_m
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
b1_gemm1n
=
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
))
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
&&
c_g
%
b_g
==
0
))
{
return
false
;
}
...
...
@@ -1152,6 +1218,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return
false
;
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if
constexpr
(
is_same
<
OutputDataType
,
half_t
>::
value
||
is_same
<
OutputDataType
,
bhalf_t
>::
value
)
{
if
(
KzRaw
%
2
!=
0
)
{
std
::
cout
<<
"K_q must be a multiple of 2"
<<
std
::
endl
;
return
false
;
}
}
// Check vector load/store requirement
const
auto
a_stride_lowest
=
ABlockTransferSrcVectorDim
==
2
?
arg
.
a_mz_kz_strides_
[
1
]
:
arg
.
a_mz_kz_strides_
[
0
];
...
...
@@ -1203,6 +1280,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_lengths
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_strides
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
...
...
@@ -1242,6 +1323,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
lse_gs_ms_lengths
,
bgrad_gs_ns_ks_lengths
,
bgrad_gs_ns_ks_strides
,
b1grad_gs_gemm1ns_gemm1ks_lengths
,
b1grad_gs_gemm1ns_gemm1ks_strides
,
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
,
acc1_bias_gs_ms_gemm1ns_lengths
,
// acc1_bias_gs_ms_os_lengths
...
...
@@ -1270,10 +1355,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
void
*
p_qgrad_grid
,
void
*
p_kgrad_grid
,
void
*
p_vgrad_grid
,
const
D0DataType
*
p_acc0_bias
,
const
D1DataType
*
p_acc1_bias
,
D0DataType
*
p_d0grad_grid
,
D1DataType
*
p_d1grad_grid
,
const
void
*
p_acc0_bias
,
const
void
*
p_acc1_bias
,
void
*
p_d0grad_grid
,
void
*
p_d1grad_grid
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
...
...
@@ -1285,6 +1370,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_lengths
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_strides
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
...
...
@@ -1312,8 +1401,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static_cast
<
OutputDataType
*>
(
p_vgrad_grid
),
static_cast
<
const
D0DataType
*>
(
p_acc0_bias
),
// cast in struct Argument
static_cast
<
const
D1DataType
*>
(
p_acc1_bias
),
// cast in struct Argument
static_cast
<
const
D0DataType
*>
(
p_d0grad_grid
),
static_cast
<
const
D1DataType
*>
(
p_d1grad_grid
),
static_cast
<
D0DataType
*>
(
p_d0grad_grid
),
static_cast
<
D1DataType
*>
(
p_d1grad_grid
),
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
b_gs_ns_ks_lengths
,
...
...
@@ -1325,6 +1414,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
lse_gs_ms_lengths
,
bgrad_gs_ns_ks_lengths
,
bgrad_gs_ns_ks_strides
,
b1grad_gs_gemm1ns_gemm1ks_lengths
,
b1grad_gs_gemm1ns_gemm1ks_strides
,
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
,
acc1_bias_gs_ms_gemm1ns_lengths
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment