gaoqiong / composable_kernel

Commit 10836d41
Authored Oct 11, 2023 by danyao12

G1/G2 -> G1Q/G1KV

Parent: 9574b34d
Showing 12 changed files with 812 additions and 780 deletions (+812, -780)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp   +69 -68
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp   +69 -68
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp      +74 -73
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp   +70 -64
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp   +70 -64
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp      +74 -68
example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc   +50 -50
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc   +61 -59
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp             +77 -76
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp             +77 -71
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward_v2.inc          +55 -55
example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward_v2.inc          +66 -64
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp

@@ -269,15 +269,15 @@ int run(int argc, char* argv[])
     // Overall QKV matrices shape
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
-    // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
-    // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
+    // y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
+    // y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
     ck::index_t M  = 512;
     ck::index_t N  = 512;
     ck::index_t K  = DIM;
     ck::index_t O  = DIM;
     ck::index_t G0 = 4;
-    ck::index_t G1 = 6; // h_q
-    ck::index_t G2 = 6; // h_kv
+    ck::index_t G1Q  = 6; // h_q
+    ck::index_t G1KV = 6; // h_kv
     bool input_permute  = false;
     bool output_permute = false;
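For readers skimming the diff: G1Q is the number of query heads (h_q) and G1KV the number of key/value heads (h_kv); the old names G1/G2 hid that relationship. Later hunks index a KV head as g1q / (G1Q / G1KV), so the examples implicitly assume G1Q is a whole multiple of G1KV. A minimal sketch of that assumption, using plain int instead of ck::index_t; the helper name is illustrative, not part of the commit:

#include <cassert>

// Illustrative helper (not repository code): map a query-head index to the
// key/value head it reads from, mirroring the "g1q / (G1Q / G1KV)" expression
// that appears later in this diff.
inline int kv_head_of(int g1q, int G1Q, int G1KV)
{
    assert(G1KV > 0 && G1Q % G1KV == 0); // the integer division below assumes this
    return g1q / (G1Q / G1KV);
}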
@@ -302,13 +302,13 @@ int run(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
         M           = std::stoi(argv[4]);
         N           = std::stoi(argv[5]);
         K           = std::stoi(argv[6]);
         O           = std::stoi(argv[7]);
         G0          = std::stoi(argv[8]);
-        G1          = std::stoi(argv[9]);
-        G2          = std::stoi(argv[10]);
+        G1Q         = std::stoi(argv[9]);
+        G1KV        = std::stoi(argv[10]);
         p_drop      = std::stof(argv[11]);

@@ -320,7 +320,7 @@ int run(int argc, char* argv[])
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
+        printf("arg4 to 10: M, N, K, O, G0, G1Q, G1KV\n");
         printf("arg11: p_drop\n");
         printf("arg12 to 13: input / output permute\n");
         exit(0);
@@ -339,8 +339,8 @@ int run(int argc, char* argv[])
     std::cout << "K: " << K << std::endl;
     std::cout << "O: " << O << std::endl;
     std::cout << "G0: " << G0 << std::endl;
-    std::cout << "G1: " << G1 << std::endl;
-    std::cout << "G2: " << G2 << std::endl;
+    std::cout << "G1Q: " << G1Q << std::endl;
+    std::cout << "G1KV: " << G1KV << std::endl;
     std::cout << "alpha: " << alpha << std::endl;
     std::cout << "input_permute: " << input_permute << std::endl;
     std::cout << "output_permute: " << output_permute << std::endl;
@@ -348,57 +348,57 @@ int run(int argc, char* argv[])
     std::cout << "seed: " << seed << std::endl;
     std::cout << "offset: " << offset << std::endl;

-    const ck::index_t BatchCount = G0 * G1;
+    const ck::index_t BatchCount = G0 * G1Q;

-    std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
+    std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1Q, M, K};
     std::vector<ck::index_t> q_gs_ms_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
-            : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
+            ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // Q layout [G0, M, G1Q, K]
+            : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1};  // Q layout [G0, G1Q, M, K]

-    std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
+    std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1KV, N, K};
     std::vector<ck::index_t> k_gs_ns_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
-            : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
+            ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // K layout [G0, N, G1KV, K]
+            : std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1};   // K layout [G0, G1KV, N, K]

-    std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
+    std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1KV, O, N};
     std::vector<ck::index_t> v_gs_os_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
-            : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
+            ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // V layout [G0, N, G1KV, O]
+            : std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O};   // V layout [G0, G1KV, N, O]

-    std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
+    std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1Q, M, O};
     std::vector<ck::index_t> y_gs_ms_os_strides =
         output_permute
-            ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
-            : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
+            ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // Y layout [G0, M, G1Q, O]
+            : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1};  // Y layout [G0, G1Q, M, O]

-    std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
+    std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
     std::vector<ck::index_t> z_gs_ms_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
-            : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
+            ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
+            : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1};  // Z layout [G0, G1Q, M, N]

-    std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
+    std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1Q, N, K};
     std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
-            : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
+            ? std::vector<ck::index_t>{N * G1Q * K, K, G1Q * K, 1} // KGrad layout [G0, N, G1Q, K]
+            : std::vector<ck::index_t>{G1Q * N * K, N * K, K, 1};  // KGrad layout [G0, G1Q, N, K]

-    std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
+    std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1Q, O, N};
     std::vector<ck::index_t> vgrad_gs_os_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
-            : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
+            ? std::vector<ck::index_t>{N * G1Q * O, O, 1, G1Q * O} // VGrad layout [G0, N, G1Q, O]
+            : std::vector<ck::index_t>{G1Q * N * O, N * O, 1, O};  // VGrad layout [G0, G1Q, N, O]

     // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
     // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
     //    = exp(Si) / exp(log(sum(exp() + ...)))
     //    = exp(Si - log(sum(exp() + ...)))
     //               ^^^^^^^^^^^^^^^^^^^^^
     //                        LSE
-    std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
-    std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
+    std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
+    std::vector<ck::index_t> lse_gs_ms_strides{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]

     Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
     Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
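The stride vectors above describe the same logical [G0, G1Q, M, K] (or [G0, G1KV, N, K]) tensor in two memory orders, selected by input_permute. A generic illustration, not repository code, of how such a stride set turns a 4-D index into a flat element offset:

#include <cstddef>
#include <vector>

// Flat offset of a 4-D index given per-dimension strides; e.g. strides
// {M * G1Q * K, K, G1Q * K, 1} describe the permuted [G0, M, G1Q, K] memory
// order of the same logical [G0, G1Q, M, K] Q tensor used above.
std::size_t flat_offset(const std::vector<int>& idx, const std::vector<int>& strides)
{
    std::size_t offset = 0;
    for(std::size_t d = 0; d < idx.size(); ++d)
        offset += static_cast<std::size_t>(idx[d]) * static_cast<std::size_t>(strides[d]);
    return offset;
}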
@@ -451,14 +451,14 @@ int run(int argc, char* argv[])
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
+        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1q, m, o]
         // dO dot O = [0; 1; 2; ...]
         break;
     case 6:
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
+        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1q, m, o]
         // assume mnko = 256
         // P = softmax(QK) = 0.0039 * ones
         // O = P V = 0.0039 * ones

@@ -471,7 +471,8 @@ int run(int argc, char* argv[])
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
+        ygrad_gs_ms_os.GenerateTensorValue(
+            GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1q, m, o]
         // assume mnko = 256
         // P = softmax(QK) = 0.0039 * ones
         // O = P V = 0.0039 * ones
@@ -493,20 +494,20 @@ int run(int argc, char* argv[])
     Tensor<LSEDataType> lse_g_m({BatchCount, M});

     q_gs_ms_ks.ForEach(
-        [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
+        [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx); });
     k_g_n_k.ForEach([&](auto& self, auto idx) {
-        const size_t& g0 = idx[0] / G1;
-        const size_t& g1 = idx[0] % G1;
-        const size_t& g2 = g1 / (G1 / G2);
-        self(idx)        = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
+        const size_t& g0   = idx[0] / G1Q;
+        const size_t& g1q  = idx[0] % G1Q;
+        const size_t& g1kv = g1q / (G1Q / G1KV);
+        self(idx)          = k_gs_ns_ks(g0, g1kv, idx[1], idx[2]);
     });
     v_g_n_o.ForEach([&](auto& self, auto idx) {
-        const size_t& g0 = idx[0] / G1;
-        const size_t& g1 = idx[0] % G1;
-        const size_t& g2 = g1 / (G1 / G2);
-        self(idx)        = v_gs_os_ns(g0, g2, idx[2], idx[1]);
+        const size_t& g0   = idx[0] / G1Q;
+        const size_t& g1q  = idx[0] % G1Q;
+        const size_t& g1kv = g1q / (G1Q / G1KV);
+        self(idx)          = v_gs_os_ns(g0, g1kv, idx[2], idx[1]);
     });

     // qkv gradients have the same descriptor as with qkv
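The k_g_n_k / v_g_n_o loops above broadcast the G1KV key/value heads across the G0 * G1Q flattened batch: the flat index idx[0] is split into (g0, g1q) and g1q is then collapsed onto a KV head. A worked sketch with hypothetical sizes G0 = 2, G1Q = 6, G1KV = 2 (the example defaults above use 6 and 6), so three query heads share each KV head:

#include <cstdio>

int main()
{
    const int G0 = 2, G1Q = 6, G1KV = 2; // hypothetical sizes, not the example defaults
    for(int g = 0; g < G0 * G1Q; ++g)    // flattened batch index, as in idx[0] above
    {
        const int g0   = g / G1Q;            // batch
        const int g1q  = g % G1Q;            // query head
        const int g1kv = g1q / (G1Q / G1KV); // KV head it reads from
        std::printf("g0=%d  g1q=%d  -> g1kv=%d\n", g0, g1q, g1kv);
    }
    return 0;
}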
@@ -651,7 +652,7 @@ int run(int argc, char* argv[])
         // copy z matirx data form device
         z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
         z_gs_ms_ns.ForEach(
-            [&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
+            [&](auto& self, auto idx) { z_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx); });
         // std::cout << "z_g_m_n ref:\n" << z_g_m_n;

         bool pass = true;

@@ -671,10 +672,10 @@ int run(int argc, char* argv[])
                               p_dropout_in_uint8_t,
                               rp_dropout);
         y_gs_ms_os.ForEach([&](auto& self, auto idx) {
-            self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
+            self(idx) = y_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]);
         });
         lse_gs_ms.ForEach(
-            [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
+            [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1Q + idx[1], idx[2]); });

         y_device_buf.ToDevice(y_gs_ms_os.mData.data());
         lse_device_buf.ToDevice(lse_gs_ms.mData.data());
@@ -692,7 +693,7 @@ int run(int argc, char* argv[])
         Tensor<InputDataType> ygrad_dot_y_g_m({BatchCount, M});
         ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
-            ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+            ygrad_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
         });

 #if PRINT_HOST

@@ -811,26 +812,26 @@ int run(int argc, char* argv[])
         // permute
         qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
             const size_t& g0 = idx[0];
-            const size_t& g1 = idx[1];
-            const size_t g   = g0 * G1 + g1;
+            const size_t& g1q = idx[1];
+            const size_t g    = g0 * G1Q + g1q;

             self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
         });
         kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
             const size_t& g0 = idx[0];
-            const size_t& g1 = idx[1];
-            const size_t g   = g0 * G1 + g1;
+            const size_t& g1q = idx[1];
+            const size_t g    = g0 * G1Q + g1q;

             self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
         });
         vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
             const size_t& g0 = idx[0];
-            const size_t& g1 = idx[1];
-            const size_t g   = g0 * G1 + g1;
+            const size_t& g1q = idx[1];
+            const size_t g    = g0 * G1Q + g1q;

             self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
         });
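The LSE comment block kept as context in the -348 hunk above records why the backward pass stores log-sum-exp per row: P_i = exp(S_i - LSE) reproduces the softmax without re-reducing over the row. A small standalone sketch of that identity (not the repository's reference implementation):

#include <cmath>
#include <cstddef>
#include <vector>

// Recompute softmax probabilities from raw scores S and the stored row LSE,
// using P_i = exp(S_i - LSE) with LSE = log(sum_j exp(S_j)).
std::vector<float> softmax_from_lse(const std::vector<float>& S, float lse)
{
    std::vector<float> P(S.size());
    for(std::size_t i = 0; i < S.size(); ++i)
        P[i] = std::exp(S[i] - lse);
    return P;
}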
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp

@@ -270,15 +270,15 @@ int run(int argc, char* argv[])
     // Overall QKV matrices shape
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
-    // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
-    // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
+    // y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
+    // y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
     ck::index_t M  = 512;
     ck::index_t N  = 512;
     ck::index_t K  = DIM;
     ck::index_t O  = DIM;
     ck::index_t G0 = 4;
-    ck::index_t G1 = 6; // h_q
-    ck::index_t G2 = 6; // h_kv
+    ck::index_t G1Q  = 6; // h_q
+    ck::index_t G1KV = 6; // h_kv
     bool input_permute  = false;
     bool output_permute = false;

@@ -303,13 +303,13 @@ int run(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
         M           = std::stoi(argv[4]);
         N           = std::stoi(argv[5]);
         K           = std::stoi(argv[6]);
         O           = std::stoi(argv[7]);
         G0          = std::stoi(argv[8]);
-        G1          = std::stoi(argv[9]);
-        G2          = std::stoi(argv[10]);
+        G1Q         = std::stoi(argv[9]);
+        G1KV        = std::stoi(argv[10]);
         p_drop      = std::stof(argv[11]);

@@ -321,7 +321,7 @@ int run(int argc, char* argv[])
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
+        printf("arg4 to 10: M, N, K, O, G0, G1Q, G1KV\n");
         printf("arg11: p_drop\n");
         printf("arg12 to 13: input / output permute\n");
         exit(0);

@@ -340,8 +340,8 @@ int run(int argc, char* argv[])
     std::cout << "K: " << K << std::endl;
     std::cout << "O: " << O << std::endl;
     std::cout << "G0: " << G0 << std::endl;
-    std::cout << "G1: " << G1 << std::endl;
-    std::cout << "G2: " << G2 << std::endl;
+    std::cout << "G1Q: " << G1Q << std::endl;
+    std::cout << "G1KV: " << G1KV << std::endl;
     std::cout << "alpha: " << alpha << std::endl;
     std::cout << "input_permute: " << input_permute << std::endl;
     std::cout << "output_permute: " << output_permute << std::endl;

@@ -349,57 +349,57 @@ int run(int argc, char* argv[])
     std::cout << "seed: " << seed << std::endl;
     std::cout << "offset: " << offset << std::endl;

-    const ck::index_t BatchCount = G0 * G1;
+    const ck::index_t BatchCount = G0 * G1Q;

-    std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
+    std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1Q, M, K};
     std::vector<ck::index_t> q_gs_ms_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
-            : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
+            ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // Q layout [G0, M, G1Q, K]
+            : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1};  // Q layout [G0, G1Q, M, K]

-    std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
+    std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1KV, N, K};
     std::vector<ck::index_t> k_gs_ns_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
-            : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
+            ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // K layout [G0, N, G1KV, K]
+            : std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1};   // K layout [G0, G1KV, N, K]

-    std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
+    std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1KV, O, N};
     std::vector<ck::index_t> v_gs_os_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
-            : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
+            ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // V layout [G0, N, G1KV, O]
+            : std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O};   // V layout [G0, G1KV, N, O]

-    std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
+    std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1Q, M, O};
     std::vector<ck::index_t> y_gs_ms_os_strides =
         output_permute
-            ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
-            : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
+            ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // Y layout [G0, M, G1Q, O]
+            : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1};  // Y layout [G0, G1Q, M, O]

-    std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
+    std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
     std::vector<ck::index_t> z_gs_ms_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
-            : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
+            ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
+            : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1};  // Z layout [G0, G1Q, M, N]

-    std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
+    std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1Q, N, K};
     std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
-            : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
+            ? std::vector<ck::index_t>{N * G1Q * K, K, G1Q * K, 1} // KGrad layout [G0, N, G1Q, K]
+            : std::vector<ck::index_t>{G1Q * N * K, N * K, K, 1};  // KGrad layout [G0, G1Q, N, K]

-    std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
+    std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1Q, O, N};
     std::vector<ck::index_t> vgrad_gs_os_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
-            : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
+            ? std::vector<ck::index_t>{N * G1Q * O, O, 1, G1Q * O} // VGrad layout [G0, N, G1Q, O]
+            : std::vector<ck::index_t>{G1Q * N * O, N * O, 1, O};  // VGrad layout [G0, G1Q, N, O]

     // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
     // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
     //    = exp(Si) / exp(log(sum(exp() + ...)))
     //    = exp(Si - log(sum(exp() + ...)))
     //               ^^^^^^^^^^^^^^^^^^^^^
     //                        LSE
-    std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
-    std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
+    std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
+    std::vector<ck::index_t> lse_gs_ms_strides{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]

     Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
     Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);

@@ -454,14 +454,14 @@ int run(int argc, char* argv[])
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
+        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1q, m, o]
         // dO dot O = [0; 1; 2; ...]
         break;
     case 6:
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
+        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1q, m, o]
         // assume mnko = 256
         // P = softmax(QK) = 0.0039 * ones
         // O = P V = 0.0039 * ones

@@ -474,7 +474,8 @@ int run(int argc, char* argv[])
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
+        ygrad_gs_ms_os.GenerateTensorValue(
+            GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1q, m, o]
         // assume mnko = 256
         // P = softmax(QK) = 0.0039 * ones
         // O = P V = 0.0039 * ones

@@ -496,20 +497,20 @@ int run(int argc, char* argv[])
     Tensor<LSEDataType> lse_g_m({BatchCount, M});

     q_gs_ms_ks.ForEach(
-        [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
+        [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx); });
     k_g_n_k.ForEach([&](auto& self, auto idx) {
-        const size_t& g0 = idx[0] / G1;
-        const size_t& g1 = idx[0] % G1;
-        const size_t& g2 = g1 / (G1 / G2);
-        self(idx)        = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
+        const size_t& g0   = idx[0] / G1Q;
+        const size_t& g1q  = idx[0] % G1Q;
+        const size_t& g1kv = g1q / (G1Q / G1KV);
+        self(idx)          = k_gs_ns_ks(g0, g1kv, idx[1], idx[2]);
     });
     v_g_n_o.ForEach([&](auto& self, auto idx) {
-        const size_t& g0 = idx[0] / G1;
-        const size_t& g1 = idx[0] % G1;
-        const size_t& g2 = g1 / (G1 / G2);
-        self(idx)        = v_gs_os_ns(g0, g2, idx[2], idx[1]);
+        const size_t& g0   = idx[0] / G1Q;
+        const size_t& g1q  = idx[0] % G1Q;
+        const size_t& g1kv = g1q / (G1Q / G1KV);
+        self(idx)          = v_gs_os_ns(g0, g1kv, idx[2], idx[1]);
     });

     // qkv gradients have the same descriptor as with qkv

@@ -657,7 +658,7 @@ int run(int argc, char* argv[])
         // copy z matirx data form device
         z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
         z_gs_ms_ns.ForEach(
-            [&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
+            [&](auto& self, auto idx) { z_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx); });
         // std::cout << "z_g_m_n ref:\n" << z_g_m_n;

         bool pass = true;

@@ -677,10 +678,10 @@ int run(int argc, char* argv[])
                               p_dropout_in_uint8_t,
                               rp_dropout);
         y_gs_ms_os.ForEach([&](auto& self, auto idx) {
-            self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
+            self(idx) = y_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]);
         });
         lse_gs_ms.ForEach(
-            [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
+            [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1Q + idx[1], idx[2]); });

         y_device_buf.ToDevice(y_gs_ms_os.mData.data());
         lse_device_buf.ToDevice(lse_gs_ms.mData.data());

@@ -698,7 +699,7 @@ int run(int argc, char* argv[])
         Tensor<InputDataType> ygrad_dot_y_g_m({BatchCount, M});
         ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
-            ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+            ygrad_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
         });

 #if PRINT_HOST

@@ -817,26 +818,26 @@ int run(int argc, char* argv[])
         // permute
         qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
             const size_t& g0 = idx[0];
-            const size_t& g1 = idx[1];
-            const size_t g   = g0 * G1 + g1;
+            const size_t& g1q = idx[1];
+            const size_t g    = g0 * G1Q + g1q;

             self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
         });
         kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
             const size_t& g0 = idx[0];
-            const size_t& g1 = idx[1];
-            const size_t g   = g0 * G1 + g1;
+            const size_t& g1q = idx[1];
+            const size_t g    = g0 * G1Q + g1q;

             self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
         });
         vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
             const size_t& g0 = idx[0];
-            const size_t& g1 = idx[1];
-            const size_t g   = g0 * G1 + g1;
+            const size_t& g1q = idx[1];
+            const size_t g    = g0 * G1Q + g1q;

             self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
         });
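One consequence of the rename worth spelling out here: the host reference tensors (q_g_m_k, k_g_n_k, v_g_n_o, lse_g_m) are all sized with BatchCount = G0 * G1Q, so K and V are physically replicated across the query-head ratio before the reference math runs. A plain-loop sketch equivalent in spirit to the Tensor::ForEach calls in this diff; row-major packing and the function name are assumptions for illustration:

#include <cstddef>
#include <vector>

// Replicate K from [G0, G1KV, N, K] into the flattened [G0 * G1Q, N, K] layout
// used by the host reference, so each query head sees its grouped KV head.
std::vector<float> broadcast_k(const std::vector<float>& k_src,
                               int G0, int G1Q, int G1KV, int N, int K)
{
    std::vector<float> k_dst(static_cast<std::size_t>(G0) * G1Q * N * K);
    for(int g = 0; g < G0 * G1Q; ++g)
    {
        const int g0   = g / G1Q;
        const int g1q  = g % G1Q;
        const int g1kv = g1q / (G1Q / G1KV);
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
            {
                const std::size_t src =
                    ((static_cast<std::size_t>(g0) * G1KV + g1kv) * N + n) * K + k;
                const std::size_t dst = (static_cast<std::size_t>(g) * N + n) * K + k;
                k_dst[dst] = k_src[src];
            }
    }
    return k_dst;
}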
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp

@@ -299,15 +299,15 @@ int run(int argc, char* argv[])
     // Overall QKV matrices shape
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
-    // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
-    // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
+    // y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
+    // y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
     ck::index_t N  = 500; // 512
     ck::index_t M  = 500; // 512
     ck::index_t K  = DIM;
     ck::index_t O  = DIM;
     ck::index_t G0 = 4;
-    ck::index_t G1 = 6; // h_q
-    ck::index_t G2 = 6; // h_kv
+    ck::index_t G1Q  = 6; // h_q
+    ck::index_t G1KV = 6; // h_kv
     bool input_permute  = false;
     bool output_permute = false;

@@ -332,13 +332,13 @@ int run(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
         M           = std::stoi(argv[4]);
         N           = std::stoi(argv[5]);
         K           = std::stoi(argv[6]);
         O           = std::stoi(argv[7]);
         G0          = std::stoi(argv[8]);
-        G1          = std::stoi(argv[9]);
-        G2          = std::stoi(argv[10]);
+        G1Q         = std::stoi(argv[9]);
+        G1KV        = std::stoi(argv[10]);
         p_drop      = std::stof(argv[11]);

@@ -350,7 +350,7 @@ int run(int argc, char* argv[])
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
+        printf("arg4 to 10: M, N, K, O, G0, G1Q, G1KV\n");
         printf("arg11: p_drop\n");
         printf("arg12 to 13: input / output permute\n");
         exit(0);

@@ -369,8 +369,8 @@ int run(int argc, char* argv[])
     std::cout << "K: " << K << std::endl;
     std::cout << "O: " << O << std::endl;
     std::cout << "G0: " << G0 << std::endl;
-    std::cout << "G1: " << G1 << std::endl;
-    std::cout << "G2: " << G2 << std::endl;
+    std::cout << "G1Q: " << G1Q << std::endl;
+    std::cout << "G1KV: " << G1KV << std::endl;
     std::cout << "alpha: " << alpha << std::endl;
     std::cout << "input_permute: " << input_permute << std::endl;
     std::cout << "output_permute: " << output_permute << std::endl;

@@ -378,57 +378,57 @@ int run(int argc, char* argv[])
     std::cout << "seed: " << seed << std::endl;
     std::cout << "offset: " << offset << std::endl;

-    const ck::index_t BatchCount = G0 * G1;
+    const ck::index_t BatchCount = G0 * G1Q;

-    std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
+    std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1Q, M, K};
     std::vector<ck::index_t> q_gs_ms_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
-            : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
+            ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // Q layout [G0, M, G1Q, K]
+            : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1};  // Q layout [G0, G1Q, M, K]

-    std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
+    std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1KV, N, K};
     std::vector<ck::index_t> k_gs_ns_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
-            : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
+            ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // K layout [G0, N, G1KV, K]
+            : std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1};   // K layout [G0, G1KV, N, K]

-    std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
+    std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1KV, O, N};
     std::vector<ck::index_t> v_gs_os_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
-            : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
+            ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // V layout [G0, N, G1KV, O]
+            : std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O};   // V layout [G0, G1KV, N, O]

-    std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
+    std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1Q, M, O};
     std::vector<ck::index_t> y_gs_ms_os_strides =
         output_permute
-            ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
-            : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
+            ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // Y layout [G0, M, G1Q, O]
+            : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1};  // Y layout [G0, G1Q, M, O]

-    std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
+    std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
     std::vector<ck::index_t> z_gs_ms_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
-            : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
+            ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
+            : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1};  // Z layout [G0, G1Q, M, N]

-    std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
+    std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1Q, N, K};
     std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
-            : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
+            ? std::vector<ck::index_t>{N * G1Q * K, K, G1Q * K, 1} // KGrad layout [G0, N, G1Q, K]
+            : std::vector<ck::index_t>{G1Q * N * K, N * K, K, 1};  // KGrad layout [G0, G1Q, N, K]

-    std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
+    std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1Q, O, N};
     std::vector<ck::index_t> vgrad_gs_os_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
-            : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
+            ? std::vector<ck::index_t>{N * G1Q * O, O, 1, G1Q * O} // VGrad layout [G0, N, G1Q, O]
+            : std::vector<ck::index_t>{G1Q * N * O, N * O, 1, O};  // VGrad layout [G0, G1Q, N, O]

     // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
     // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
     //    = exp(Si) / exp(log(sum(exp() + ...)))
     //    = exp(Si - log(sum(exp() + ...)))
     //               ^^^^^^^^^^^^^^^^^^^^^
     //                        LSE
-    std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
-    std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
+    std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
+    std::vector<ck::index_t> lse_gs_ms_strides{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]

     Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
     Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);

@@ -484,14 +484,14 @@ int run(int argc, char* argv[])
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
+        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1q, m, o]
         // dO dot O = [0; 1; 2; ...]
         break;
     case 6:
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
+        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1q, m, o]
         // assume mnko = 256
         // P = softmax(QK) = 0.0039 * ones
         // O = P V = 0.0039 * ones

@@ -504,7 +504,8 @@ int run(int argc, char* argv[])
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
+        ygrad_gs_ms_os.GenerateTensorValue(
+            GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1q, m, o]
         // assume mnko = 256
         // P = softmax(QK) = 0.0039 * ones
         // O = P V = 0.0039 * ones

@@ -820,24 +821,24 @@ int run(int argc, char* argv[])
     }

     q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
-        q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        q_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
     });
     k_g_n_k.ForEach([&](auto& self, auto idx) {
-        const size_t& g0 = idx[0] / G1;
-        const size_t& g1 = idx[0] % G1;
-        const size_t& g2 = g1 / (G1 / G2);
-        self(idx)        = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
+        const size_t& g0   = idx[0] / G1Q;
+        const size_t& g1q  = idx[0] % G1Q;
+        const size_t& g1kv = g1q / (G1Q / G1KV);
+        self(idx)          = k_gs_ns_ks(g0, g1kv, idx[1], idx[2]);
     });
     v_g_n_o.ForEach([&](auto& self, auto idx) {
-        const size_t& g0 = idx[0] / G1;
-        const size_t& g1 = idx[0] % G1;
-        const size_t& g2 = g1 / (G1 / G2);
-        self(idx)        = v_gs_os_ns(g0, g2, idx[2], idx[1]);
+        const size_t& g0   = idx[0] / G1Q;
+        const size_t& g1q  = idx[0] % G1Q;
+        const size_t& g1kv = g1q / (G1Q / G1KV);
+        self(idx)          = v_gs_os_ns(g0, g1kv, idx[2], idx[1]);
     });
     z_fwd_gs_ms_ns.ForEach([&](auto& self, auto idx) {
-        z_fwd_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        z_fwd_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
     });

     run_attention_fwd_host(q_g_m_k,

@@ -854,10 +855,10 @@ int run(int argc, char* argv[])
                            rp_dropout);

     ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
-        ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        ygrad_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
     });
     z_bwd_gs_ms_ns.ForEach([&](auto& self, auto idx) {
-        z_bwd_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        z_bwd_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
     });

 #if PRINT_HOST

@@ -960,42 +961,42 @@ int run(int argc, char* argv[])
         // permute
         y_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
             const size_t& g0 = idx[0];
-            const size_t& g1 = idx[1];
-            const size_t g   = g0 * G1 + g1;
+            const size_t& g1q = idx[1];
+            const size_t g    = g0 * G1Q + g1q;

             self(idx) = y_g_m_o(g, idx[2], idx[3]);
         });
         lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) {
             const size_t& g0 = idx[0];
-            const size_t& g1 = idx[1];
-            const size_t g   = g0 * G1 + g1;
+            const size_t& g1q = idx[1];
+            const size_t g    = g0 * G1Q + g1q;

             self(idx) = lse_g_m(g, idx[2]);
         });
         qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
             const size_t& g0 = idx[0];
-            const size_t& g1 = idx[1];
-            const size_t g   = g0 * G1 + g1;
+            const size_t& g1q = idx[1];
+            const size_t g    = g0 * G1Q + g1q;

             self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
         });
         kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
             const size_t& g0 = idx[0];
-            const size_t& g1 = idx[1];
-            const size_t g   = g0 * G1 + g1;
+            const size_t& g1q = idx[1];
+            const size_t g    = g0 * G1Q + g1q;

             self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
         });
         vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
             const size_t& g0 = idx[0];
-            const size_t& g1 = idx[1];
-            const size_t g   = g0 * G1 + g1;
+            const size_t& g1q = idx[1];
+            const size_t g    = g0 * G1Q + g1q;

             self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
         });
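The "P = softmax(QK) = 0.0039 * ones" comments carried through the init cases above come from simple arithmetic: with the all-ones Q and diagonal K initialization the rows of the score matrix end up uniform, so softmax over the assumed N = 256 columns gives exactly 1/256 per entry. A one-line check of that number (illustrative, not repository code):

#include <cstdio>

int main()
{
    const int N = 256;              // the "assume mnko = 256" case in the comments above
    std::printf("%.8f\n", 1.0 / N); // 0.00390625, i.e. the 0.0039 the comments refer to
    return 0;
}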
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp

@@ -268,11 +268,11 @@ int run(int argc, char* argv[])
     // Overall QKV matrices shape
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
-    // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
-    // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
+    // y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
+    // y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
     float alpha  = 1.f / std::sqrt(DIM);
     float p_drop = 0.0;
-    int h_ratio  = 1; // G1 / G2
+    int h_ratio  = 1; // G1Q / G1KV
     bool input_permute  = true;
     bool output_permute = true;
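In the grouped example the head counts are not read from the command line; G1KV is drawn per problem and G1Q derived from it via h_ratio, so the divisibility constraint holds by construction. A sketch of that generation step, as it appears in the hunk below (the function wrapper is illustrative only):

#include <cstdlib>

// Per-group head counts as generated in the hunk below: G1KV is random,
// and G1Q = G1KV * h_ratio guarantees G1Q % G1KV == 0.
void make_head_counts(int h_ratio, int& G1Q, int& G1KV)
{
    G1KV = std::rand() % 4 + 1; // 1..4 KV heads, as in the example
    G1Q  = G1KV * h_ratio;      // query heads are an exact multiple of KV heads
}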
...
@@ -369,61 +369,65 @@ int run(int argc, char* argv[])
    std::size_t flop = 0, num_byte = 0;
    for(std::size_t i = 0; i < group_count; i++)
    {
        int M = 128 * (rand() % 8) + (rand() % 128);
        int N = 128 * (rand() % 8) + (rand() % 128);
        int K = DIM;
        int O = DIM;
        int G0 = rand() % 4 + 1;
-       int G2 = rand() % 4 + 1;
+       int G1KV = rand() % 4 + 1;
-       int G1 = G2 * h_ratio;
+       int G1Q = G1KV * h_ratio;
-       std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
+       std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1Q, M, K};
        std::vector<ck::index_t> q_gs_ms_ks_strides =
            input_permute
-               ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
+               ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // Q layout [G0, M, G1Q, K]
-               : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
+               : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // Q layout [G0, G1Q, M, K]
-       std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
+       std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1KV, N, K};
        std::vector<ck::index_t> k_gs_ns_ks_strides =
            input_permute
-               ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
+               ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // K layout [G0, N, G1KV, K]
-               : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
+               : std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // K layout [G0, G1KV, N, K]
-       std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
+       std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1KV, O, N};
        std::vector<ck::index_t> v_gs_os_ns_strides =
            input_permute
-               ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
+               ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // V layout [G0, N, G1KV, O]
-               : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
+               : std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // V layout [G0, G1KV, N, O]
-       std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
+       std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1Q, M, O};
        std::vector<ck::index_t> y_gs_ms_os_strides =
            output_permute
-               ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
+               ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // Y layout [G0, M, G1Q, O]
-               : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
+               : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // Y layout [G0, G1Q, M, O]
-       std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
+       std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
        std::vector<ck::index_t> z_gs_ms_ns_strides =
            input_permute
-               ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
+               ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
-               : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
+               : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
-       std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
+       std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1Q, N, K};
        std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
            input_permute
-               ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
+               ? std::vector<ck::index_t>{N * G1Q * K, K, G1Q * K, 1} // KGrad layout [G0, N, G1Q, K]
-               : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
+               : std::vector<ck::index_t>{G1Q * N * K, N * K, K, 1}; // KGrad layout [G0, G1Q, N, K]
-       std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
+       std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1Q, O, N};
        std::vector<ck::index_t> vgrad_gs_os_ns_strides =
            input_permute
-               ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
+               ? std::vector<ck::index_t>{N * G1Q * O, O, 1, G1Q * O} // VGrad layout [G0, N, G1Q, O]
-               : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
+               : std::vector<ck::index_t>{G1Q * N * O, N * O, 1, O}; // VGrad layout [G0, G1Q, N, O]
        // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
        // pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
        //         = exp(Si) / exp(log(sum(exp() + ...)))
        //         = exp(Si - log(sum(exp() + ...)))
        //                     ^^^^^^^^^^^^^^^^^^^^^
        //                              LSE
-       std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
+       std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
-       std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
+       std::vector<ck::index_t> lse_gs_ms_strides{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
        problem_descs.push_back({
            q_gs_ms_ks_lengths,
            q_gs_ms_ks_strides,
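Spelled out, the identity behind the LSE comment above is the log-sum-exp form of softmax:

\[
P_i \;=\; \frac{e^{S_i}}{\sum_j e^{S_j}}
     \;=\; \frac{e^{S_i}}{e^{\log \sum_j e^{S_j}}}
     \;=\; e^{\,S_i - \mathrm{LSE}},
\qquad
\mathrm{LSE} \;=\; \log \sum_j e^{S_j},
\]

so the forward pass only needs to keep one scalar per softmax row (the [G0, G1Q, M] tensor declared above) for the backward pass to rebuild P without redoing the row-wise max/sum reductions.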
...
@@ -447,7 +451,7 @@ int run(int argc, char* argv[])
            {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
        });
-       int BatchCount = G0 * G1;
+       int BatchCount = G0 * G1Q;
        flop += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
        // Q/K/V/Y, dQ/dK/dV/dY, LSE
        num_byte += (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
@@ -509,14 +513,16 @@ int run(int argc, char* argv[])
...
@@ -509,14 +513,16 @@ int run(int argc, char* argv[])
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1q, m, o]
// dO dot O = [0; 1; 2; ...]
// dO dot O = [0; 1; 2; ...]
break
;
break
;
case
6
:
case
6
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1q, m, o]
// assume mnko = 256
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// O = P V = 0.0039 * ones
...
@@ -530,7 +536,7 @@ int run(int argc, char* argv[])
...
@@ -530,7 +536,7 @@ int run(int argc, char* argv[])
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1, m, o]
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1
q
, m, o]
// assume mnko = 256
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// O = P V = 0.0039 * ones
...
@@ -551,21 +557,21 @@ int run(int argc, char* argv[])
...
@@ -551,21 +557,21 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
p_drop_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
InputDataType
>
p_drop_g_m_n
({
BatchCount
,
M
,
N
});
q_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
q_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
q_g_m_k
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
});
k_g_n_k
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1
;
const
size_t
&
g0
=
idx
[
0
]
/
G1
Q
;
const
size_t
&
g1
=
idx
[
0
]
%
G1
;
const
size_t
&
g1
q
=
idx
[
0
]
%
G1
Q
;
const
size_t
&
g
2
=
g1
/
h_ratio
;
const
size_t
&
g
1kv
=
g1
q
/
h_ratio
;
self
(
idx
)
=
k_gs_ns_ks
(
g0
,
g
2
,
idx
[
1
],
idx
[
2
]);
self
(
idx
)
=
k_gs_ns_ks
(
g0
,
g
1kv
,
idx
[
1
],
idx
[
2
]);
});
});
v_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1
;
const
size_t
&
g0
=
idx
[
0
]
/
G1
Q
;
const
size_t
&
g1
=
idx
[
0
]
%
G1
;
const
size_t
&
g1
q
=
idx
[
0
]
%
G1
Q
;
const
size_t
&
g
2
=
g1
/
h_ratio
;
const
size_t
&
g
1kv
=
g1
q
/
h_ratio
;
self
(
idx
)
=
v_gs_os_ns
(
g0
,
g
2
,
idx
[
2
],
idx
[
1
]);
self
(
idx
)
=
v_gs_os_ns
(
g0
,
g
1kv
,
idx
[
2
],
idx
[
1
]);
});
});
q_g_m_ks
.
push_back
(
q_g_m_k
);
q_g_m_ks
.
push_back
(
q_g_m_k
);
...
@@ -706,11 +712,11 @@ int run(int argc, char* argv[])
    for(std::size_t i = 0; i < group_count; i++)
    {
-       int G1 = q_tensors[i].GetLengths()[1];
+       int G1Q = q_tensors[i].GetLengths()[1];
        // copy z matirx data form device
        z_tensors_device[i]->FromDevice(z_tensors[i].mData.data());
        z_tensors[i].ForEach([&](auto& self, auto idx) {
-           z_g_m_ns[i](idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+           z_g_m_ns[i](idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
        });
        run_attention_fwd_host(q_g_m_ks[i],
                               k_g_n_ks[i],
...
@@ -726,11 +732,11 @@ int run(int argc, char* argv[])
                               rp_dropout);
        y_tensors[i].ForEach([&](auto& self, auto idx) {
-           self(idx) = y_g_m_os[i](idx[0] * G1 + idx[1], idx[2], idx[3]);
+           self(idx) = y_g_m_os[i](idx[0] * G1Q + idx[1], idx[2], idx[3]);
        });
        y_tensors_device[i]->ToDevice(y_tensors[i].data());
        lse_tensors[i].ForEach([&](auto& self, auto idx) {
-           self(idx) = lse_g_ms[i](idx[0] * G1 + idx[1], idx[2]);
+           self(idx) = lse_g_ms[i](idx[0] * G1Q + idx[1], idx[2]);
        });
        lse_tensors_device[i]->ToDevice(lse_tensors[i].data());
        qgrad_tensors_device[i]->SetZero();
...
@@ -744,12 +750,12 @@ int run(int argc, char* argv[])
    {
        int G0 = q_tensors[i].GetLengths()[0];
-       int G1 = q_tensors[i].GetLengths()[1];
+       int G1Q = q_tensors[i].GetLengths()[1];
        int O = v_tensors[i].GetLengths()[2];
        int N = v_tensors[i].GetLengths()[3];
        int M = q_tensors[i].GetLengths()[2];
        int K = q_tensors[i].GetLengths()[3];
-       int BatchCount = G0 * G1;
+       int BatchCount = G0 * G1Q;
        Tensor<OutputDataType> qgrad_g_m_k({BatchCount, M, K});
        Tensor<OutputDataType> kgrad_g_n_k({BatchCount, N, K});
        Tensor<OutputDataType> vgrad_g_n_o({BatchCount, N, O});
...
@@ -759,7 +765,7 @@ int run(int argc, char* argv[])
        Tensor<InputDataType> ygrad_g_m_o({BatchCount, M, O});
        ygrad_tensors[i].ForEach([&](auto& self, auto idx) {
-           ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+           ygrad_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
        });
        auto ref_gemm0_grad         = ReferenceGemm0GradInstance{};
        auto ref_gemm0_grad_invoker = ref_gemm0_grad.MakeInvoker();
...
@@ -819,26 +825,26 @@ int run(int argc, char* argv[])
        vgrad_tensors_device[i]->FromDevice(vgrad_gs_os_ns_device_result.data());
        // permute
        qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
+           const size_t& g1q = idx[1];
-           const size_t g = g0 * G1 + g1;
+           const size_t g = g0 * G1Q + g1q;
            self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
        });
        kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
+           const size_t& g1q = idx[1];
-           const size_t g = g0 * G1 + g1;
+           const size_t g = g0 * G1Q + g1q;
            self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
        });
        vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
+           const size_t& g1q = idx[1];
-           const size_t g = g0 * G1 + g1;
+           const size_t g = g0 * G1Q + g1q;
            self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
        });
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp
View file @ 10836d41
...
@@ -269,11 +269,11 @@ int run(int argc, char* argv[])
    // Overall QKV matrices shape
    // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
-   // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
+   // y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
-   // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
+   // y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
    float alpha  = 1.f / std::sqrt(DIM);
    float p_drop = 0.0;
-   int h_ratio = 1; // G1 / G2
+   int h_ratio = 1; // G1Q / G1KV
    bool input_permute  = true;
    bool output_permute = true;
...
@@ -373,61 +373,65 @@ int run(int argc, char* argv[])
    std::size_t flop = 0, num_byte = 0;
    for(std::size_t i = 0; i < group_count; i++)
    {
        int M = 128 * (rand() % 8) + (rand() % 128);
        int N = 128 * (rand() % 8) + (rand() % 128);
        int K = DIM;
        int O = DIM;
        int G0 = rand() % 4 + 1;
-       int G2 = rand() % 4 + 1;
+       int G1KV = rand() % 4 + 1;
-       int G1 = G2 * h_ratio;
+       int G1Q = G1KV * h_ratio;
-       std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
+       std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1Q, M, K};
        std::vector<ck::index_t> q_gs_ms_ks_strides =
            input_permute
-               ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
+               ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // Q layout [G0, M, G1Q, K]
-               : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
+               : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // Q layout [G0, G1Q, M, K]
-       std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
+       std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1KV, N, K};
        std::vector<ck::index_t> k_gs_ns_ks_strides =
            input_permute
-               ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
+               ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // K layout [G0, N, G1KV, K]
-               : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
+               : std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // K layout [G0, G1KV, N, K]
-       std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
+       std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1KV, O, N};
        std::vector<ck::index_t> v_gs_os_ns_strides =
            input_permute
-               ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
+               ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // V layout [G0, N, G1KV, O]
-               : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
+               : std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // V layout [G0, G1KV, N, O]
-       std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
+       std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1Q, M, O};
        std::vector<ck::index_t> y_gs_ms_os_strides =
            output_permute
-               ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
+               ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // Y layout [G0, M, G1Q, O]
-               : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
+               : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // Y layout [G0, G1Q, M, O]
-       std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
+       std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
        std::vector<ck::index_t> z_gs_ms_ns_strides =
            input_permute
-               ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
+               ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
-               : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
+               : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
-       std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
+       std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1Q, N, K};
        std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
            input_permute
-               ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
+               ? std::vector<ck::index_t>{N * G1Q * K, K, G1Q * K, 1} // KGrad layout [G0, N, G1Q, K]
-               : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
+               : std::vector<ck::index_t>{G1Q * N * K, N * K, K, 1}; // KGrad layout [G0, G1Q, N, K]
-       std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
+       std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1Q, O, N};
        std::vector<ck::index_t> vgrad_gs_os_ns_strides =
            input_permute
-               ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
+               ? std::vector<ck::index_t>{N * G1Q * O, O, 1, G1Q * O} // VGrad layout [G0, N, G1Q, O]
-               : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
+               : std::vector<ck::index_t>{G1Q * N * O, N * O, 1, O}; // VGrad layout [G0, G1Q, N, O]
        // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
        // pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
        //         = exp(Si) / exp(log(sum(exp() + ...)))
        //         = exp(Si - log(sum(exp() + ...)))
        //                     ^^^^^^^^^^^^^^^^^^^^^
        //                              LSE
-       std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
+       std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
-       std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
+       std::vector<ck::index_t> lse_gs_ms_strides{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
        problem_descs.push_back({
            q_gs_ms_ks_lengths,
            q_gs_ms_ks_strides,
...
@@ -451,7 +455,7 @@ int run(int argc, char* argv[])
            {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
        });
-       int BatchCount = G0 * G1;
+       int BatchCount = G0 * G1Q;
        flop += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
        // Q/K/V/Y, dQ/dK/dV/dY, LSE
        num_byte += (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
...
@@ -515,14 +519,16 @@ int run(int argc, char* argv[])
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-       ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
+       ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1q, m, o]
        // dO dot O = [0; 1; 2; ...]
        break;
    case 6:
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-       ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
+       ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1q, m, o]
        // assume mnko = 256
        // P = softmax(QK) = 0.0039 * ones
        // O = P V = 0.0039 * ones
...
@@ -536,7 +542,7 @@ int run(int argc, char* argv[])
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
        ygrad_gs_ms_os.GenerateTensorValue(
-           GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
+           GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1q, m, o]
        // assume mnko = 256
        // P = softmax(QK) = 0.0039 * ones
        // O = P V = 0.0039 * ones
...
@@ -558,21 +564,21 @@ int run(int argc, char* argv[])
        Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
        q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
-           q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+           q_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
        });
        k_g_n_k.ForEach([&](auto& self, auto idx) {
-           const size_t& g0 = idx[0] / G1;
+           const size_t& g0 = idx[0] / G1Q;
-           const size_t& g1 = idx[0] % G1;
+           const size_t& g1q = idx[0] % G1Q;
-           const size_t& g2 = g1 / h_ratio;
+           const size_t& g1kv = g1q / h_ratio;
-           self(idx) = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
+           self(idx) = k_gs_ns_ks(g0, g1kv, idx[1], idx[2]);
        });
        v_g_n_o.ForEach([&](auto& self, auto idx) {
-           const size_t& g0 = idx[0] / G1;
+           const size_t& g0 = idx[0] / G1Q;
-           const size_t& g1 = idx[0] % G1;
+           const size_t& g1q = idx[0] % G1Q;
-           const size_t& g2 = g1 / h_ratio;
+           const size_t& g1kv = g1q / h_ratio;
-           self(idx) = v_gs_os_ns(g0, g2, idx[2], idx[1]);
+           self(idx) = v_gs_os_ns(g0, g1kv, idx[2], idx[1]);
        });
        q_g_m_ks.push_back(q_g_m_k);
...
@@ -719,11 +725,11 @@ int run(int argc, char* argv[])
    for(std::size_t i = 0; i < group_count; i++)
    {
-       int G1 = q_tensors[i].GetLengths()[1];
+       int G1Q = q_tensors[i].GetLengths()[1];
        // copy z matirx data form device
        z_tensors_device[i]->FromDevice(z_tensors[i].mData.data());
        z_tensors[i].ForEach([&](auto& self, auto idx) {
-           z_g_m_ns[i](idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+           z_g_m_ns[i](idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
        });
        run_attention_fwd_host(q_g_m_ks[i],
                               k_g_n_ks[i],
...
@@ -739,11 +745,11 @@ int run(int argc, char* argv[])
                               rp_dropout);
        y_tensors[i].ForEach([&](auto& self, auto idx) {
-           self(idx) = y_g_m_os[i](idx[0] * G1 + idx[1], idx[2], idx[3]);
+           self(idx) = y_g_m_os[i](idx[0] * G1Q + idx[1], idx[2], idx[3]);
        });
        y_tensors_device[i]->ToDevice(y_tensors[i].data());
        lse_tensors[i].ForEach([&](auto& self, auto idx) {
-           self(idx) = lse_g_ms[i](idx[0] * G1 + idx[1], idx[2]);
+           self(idx) = lse_g_ms[i](idx[0] * G1Q + idx[1], idx[2]);
        });
        lse_tensors_device[i]->ToDevice(lse_tensors[i].data());
        qgrad_tensors_device[i]->SetZero();
...
@@ -757,12 +763,12 @@ int run(int argc, char* argv[])
    {
        int G0 = q_tensors[i].GetLengths()[0];
-       int G1 = q_tensors[i].GetLengths()[1];
+       int G1Q = q_tensors[i].GetLengths()[1];
        int O = v_tensors[i].GetLengths()[2];
        int N = v_tensors[i].GetLengths()[3];
        int M = q_tensors[i].GetLengths()[2];
        int K = q_tensors[i].GetLengths()[3];
-       int BatchCount = G0 * G1;
+       int BatchCount = G0 * G1Q;
        Tensor<OutputDataType> qgrad_g_m_k({BatchCount, M, K});
        Tensor<OutputDataType> kgrad_g_n_k({BatchCount, N, K});
        Tensor<OutputDataType> vgrad_g_n_o({BatchCount, N, O});
...
@@ -772,7 +778,7 @@ int run(int argc, char* argv[])
        Tensor<InputDataType> ygrad_g_m_o({BatchCount, M, O});
        ygrad_tensors[i].ForEach([&](auto& self, auto idx) {
-           ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+           ygrad_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
        });
        auto ref_gemm0_grad         = ReferenceGemm0GradInstance{};
        auto ref_gemm0_grad_invoker = ref_gemm0_grad.MakeInvoker();
...
@@ -832,26 +838,26 @@ int run(int argc, char* argv[])
        vgrad_tensors_device[i]->FromDevice(vgrad_gs_os_ns_device_result.data());
        // permute
        qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
+           const size_t& g1q = idx[1];
-           const size_t g = g0 * G1 + g1;
+           const size_t g = g0 * G1Q + g1q;
            self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
        });
        kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
+           const size_t& g1q = idx[1];
-           const size_t g = g0 * G1 + g1;
+           const size_t g = g0 * G1Q + g1q;
            self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
        });
        vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
+           const size_t& g1q = idx[1];
-           const size_t g = g0 * G1 + g1;
+           const size_t g = g0 * G1Q + g1q;
            self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
        });
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
View file @ 10836d41
...
@@ -298,11 +298,11 @@ int run(int argc, char* argv[])
    // Overall QKV matrices shape
    // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
-   // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
+   // y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
-   // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
+   // y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
    float alpha  = 1.f / std::sqrt(DIM);
    float p_drop = 0.2;
-   int h_ratio = 1; // G1 / G2
+   int h_ratio = 1; // G1Q / G1KV
    bool input_permute  = true;
    bool output_permute = true;
...
@@ -409,61 +409,65 @@ int run(int argc, char* argv[])
    std::size_t flop_bwd = 0, num_byte_bwd = 0;
    for(std::size_t i = 0; i < group_count; i++)
    {
        int M = 128 * (rand() % 8) + (rand() % 128);
        int N = 128 * (rand() % 8) + (rand() % 128);
        int K = DIM;
        int O = DIM;
        int G0 = rand() % 4 + 1;
-       int G2 = rand() % 4 + 1;
+       int G1KV = rand() % 4 + 1;
-       int G1 = G2 * h_ratio;
+       int G1Q = G1KV * h_ratio;
-       std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
+       std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1Q, M, K};
        std::vector<ck::index_t> q_gs_ms_ks_strides =
            input_permute
-               ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
+               ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // Q layout [G0, M, G1Q, K]
-               : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
+               : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // Q layout [G0, G1Q, M, K]
-       std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
+       std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1KV, N, K};
        std::vector<ck::index_t> k_gs_ns_ks_strides =
            input_permute
-               ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
+               ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // K layout [G0, N, G1KV, K]
-               : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
+               : std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // K layout [G0, G1KV, N, K]
-       std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
+       std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1KV, O, N};
        std::vector<ck::index_t> v_gs_os_ns_strides =
            input_permute
-               ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
+               ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // V layout [G0, N, G1KV, O]
-               : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
+               : std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // V layout [G0, G1KV, N, O]
-       std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
+       std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1Q, M, O};
        std::vector<ck::index_t> y_gs_ms_os_strides =
            output_permute
-               ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
+               ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // Y layout [G0, M, G1Q, O]
-               : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
+               : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // Y layout [G0, G1Q, M, O]
-       std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
+       std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
        std::vector<ck::index_t> z_gs_ms_ns_strides =
            input_permute
-               ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
+               ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
-               : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
+               : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
-       std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
+       std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1Q, N, K};
        std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
            input_permute
-               ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
+               ? std::vector<ck::index_t>{N * G1Q * K, K, G1Q * K, 1} // KGrad layout [G0, N, G1Q, K]
-               : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
+               : std::vector<ck::index_t>{G1Q * N * K, N * K, K, 1}; // KGrad layout [G0, G1Q, N, K]
-       std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
+       std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1Q, O, N};
        std::vector<ck::index_t> vgrad_gs_os_ns_strides =
            input_permute
-               ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
+               ? std::vector<ck::index_t>{N * G1Q * O, O, 1, G1Q * O} // VGrad layout [G0, N, G1Q, O]
-               : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
+               : std::vector<ck::index_t>{G1Q * N * O, N * O, 1, O}; // VGrad layout [G0, G1Q, N, O]
        // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
        // pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
        //         = exp(Si) / exp(log(sum(exp() + ...)))
        //         = exp(Si - log(sum(exp() + ...)))
        //                     ^^^^^^^^^^^^^^^^^^^^^
        //                              LSE
-       std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
+       std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
-       std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
+       std::vector<ck::index_t> lse_gs_ms_strides{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
        problem_descs_fwd.push_back({
            q_gs_ms_ks_lengths,
            q_gs_ms_ks_strides,
...
@@ -505,7 +509,7 @@ int run(int argc, char* argv[])
            {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
        });
-       int BatchCount = G0 * G1;
+       int BatchCount = G0 * G1Q;
        flop_fwd += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
        num_byte_fwd += (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
                         sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O) *
...
@@ -574,14 +578,16 @@ int run(int argc, char* argv[])
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-       ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
+       ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1q, m, o]
        // dO dot O = [0; 1; 2; ...]
        break;
    case 6:
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-       ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
+       ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1q, m, o]
        // assume mnko = 256
        // P = softmax(QK) = 0.0039 * ones
        // O = P V = 0.0039 * ones
...
@@ -595,7 +601,7 @@ int run(int argc, char* argv[])
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
        ygrad_gs_ms_os.GenerateTensorValue(
-           GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
+           GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1q, m, o]
        // assume mnko = 256
        // P = softmax(QK) = 0.0039 * ones
        // O = P V = 0.0039 * ones
...
@@ -618,21 +624,21 @@ int run(int argc, char* argv[])
        Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
        q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
-           q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+           q_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
        });
        k_g_n_k.ForEach([&](auto& self, auto idx) {
-           const size_t& g0 = idx[0] / G1;
+           const size_t& g0 = idx[0] / G1Q;
-           const size_t& g1 = idx[0] % G1;
+           const size_t& g1q = idx[0] % G1Q;
-           const size_t& g2 = g1 / h_ratio;
+           const size_t& g1kv = g1q / h_ratio;
-           self(idx) = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
+           self(idx) = k_gs_ns_ks(g0, g1kv, idx[1], idx[2]);
        });
        v_g_n_o.ForEach([&](auto& self, auto idx) {
-           const size_t& g0 = idx[0] / G1;
+           const size_t& g0 = idx[0] / G1Q;
-           const size_t& g1 = idx[0] % G1;
+           const size_t& g1q = idx[0] % G1Q;
-           const size_t& g2 = g1 / h_ratio;
+           const size_t& g1kv = g1q / h_ratio;
-           self(idx) = v_gs_os_ns(g0, g2, idx[2], idx[1]);
+           self(idx) = v_gs_os_ns(g0, g1kv, idx[2], idx[1]);
        });
        q_g_m_ks.push_back(q_g_m_k);
...
@@ -872,15 +878,15 @@ int run(int argc, char* argv[])
    for(std::size_t i = 0; i < group_count; i++)
    {
-       int G1 = q_tensors[i].GetLengths()[1];
+       int G1Q = q_tensors[i].GetLengths()[1];
        // copy z matirx data form device
        z_fwd_tensors_device[i]->FromDevice(z_fwd_tensors[i].mData.data());
        z_fwd_tensors[i].ForEach([&](auto& self, auto idx) {
-           z_fwd_g_m_ns[i](idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+           z_fwd_g_m_ns[i](idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
        });
        z_bwd_tensors_device[i]->FromDevice(z_bwd_tensors[i].mData.data());
        z_bwd_tensors[i].ForEach([&](auto& self, auto idx) {
-           z_bwd_g_m_ns[i](idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+           z_bwd_g_m_ns[i](idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
        });
        run_attention_fwd_host(q_g_m_ks[i],
                               k_g_n_ks[i],
...
@@ -900,7 +906,7 @@ int run(int argc, char* argv[])
        int N = v_tensors[i].GetLengths()[3];
        int M = q_tensors[i].GetLengths()[2];
        int K = q_tensors[i].GetLengths()[3];
-       int BatchCount = G0 * G1;
+       int BatchCount = G0 * G1Q;
        Tensor<OutputDataType> qgrad_g_m_k({BatchCount, M, K});
        Tensor<OutputDataType> kgrad_g_n_k({BatchCount, N, K});
        Tensor<OutputDataType> vgrad_g_n_o({BatchCount, N, O});
...
@@ -910,7 +916,7 @@ int run(int argc, char* argv[])
        Tensor<InputDataType> ygrad_g_m_o({BatchCount, M, O});
        ygrad_tensors[i].ForEach([&](auto& self, auto idx) {
-           ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+           ygrad_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
        });
        auto ref_gemm0_grad         = ReferenceGemm0GradInstance{};
        auto ref_gemm0_grad_invoker = ref_gemm0_grad.MakeInvoker();
...
@@ -981,42 +987,42 @@ int run(int argc, char* argv[])
        // permute
        y_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
+           const size_t& g1q = idx[1];
-           const size_t g = g0 * G1 + g1;
+           const size_t g = g0 * G1Q + g1q;
            self(idx) = y_g_m_os[i](g, idx[2], idx[3]);
        });
        lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
+           const size_t& g1q = idx[1];
-           const size_t g = g0 * G1 + g1;
+           const size_t g = g0 * G1Q + g1q;
            self(idx) = lse_g_ms[i](g, idx[2]);
        });
        qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
+           const size_t& g1q = idx[1];
-           const size_t g = g0 * G1 + g1;
+           const size_t g = g0 * G1Q + g1q;
            self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
        });
        kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
+           const size_t& g1q = idx[1];
-           const size_t g = g0 * G1 + g1;
+           const size_t g = g0 * G1Q + g1q;
            self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
        });
        vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
+           const size_t& g1q = idx[1];
-           const size_t g = g0 * G1 + g1;
+           const size_t g = g0 * G1Q + g1q;
            self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
        });
...
example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc
View file @ 10836d41
...
@@ -14,12 +14,12 @@ int run(int argc, char* argv[])
    ck::index_t K = DIM;
    ck::index_t O = DIM;
-   // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
+   // Output shape C[G0, M, G1Q, O]. Batch dim, outer dim, inner dim must match GEMM shape
-   // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
+   // C_g0_g1q_m_o = reshape(C_g_m_o, [g0, g1q, m, o])
-   // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
+   // C_g0_m_g1q_o = permute(C_g0_g1q_m_o, [0, 2, 1, 3])
    ck::index_t G0 = 7;
-   ck::index_t G1 = 12; // h_q
+   ck::index_t G1Q = 12; // h_q
-   ck::index_t G2 = 12; // h_kv
+   ck::index_t G1KV = 12; // h_kv
    bool input_permute  = false;
    bool output_permute = true;
...
@@ -44,13 +44,13 @@ int run(int argc, char* argv[])
        init_method = std::stoi(argv[2]);
        time_kernel = std::stoi(argv[3]);
        M  = std::stoi(argv[4]);
        N  = std::stoi(argv[5]);
        K  = std::stoi(argv[6]);
        O  = std::stoi(argv[7]);
        G0 = std::stoi(argv[8]);
-       G1 = std::stoi(argv[9]);
+       G1Q = std::stoi(argv[9]);
-       G2 = std::stoi(argv[10]);
+       G1KV = std::stoi(argv[10]);
        p_drop = std::stof(argv[11]);
...
@@ -62,7 +62,7 @@ int run(int argc, char* argv[])
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
-       printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
+       printf("arg4 to 10: M, N, K, O, G0, G1Q, G1KV\n");
        printf("arg11: p_drop\n");
        printf("arg12 to 13: input / output permute\n");
        exit(0);
...
@@ -73,39 +73,39 @@ int run(int argc, char* argv[])
    float rp_dropout = 1.0 / p_dropout;
    float alpha      = 1.f / std::sqrt(K);
-   std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
+   std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1Q, M, K};
    std::vector<ck::index_t> a_gs_ms_ks_strides =
        input_permute
-           ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
+           ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // A layout [G0, M, G1Q, K]
-           : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
+           : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // A layout [G0, G1Q, M, K]
-   std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G2, N, K};
+   std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1KV, N, K};
    std::vector<ck::index_t> b0_gs_ns_ks_strides =
        input_permute
-           ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // B0 layout [G0, N, G2, K]
+           ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // B0 layout [G0, N, G1KV, K]
-           : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // B0 layout [G0, G2, N, K]
+           : std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // B0 layout [G0, G1KV, N, K]
-   std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G2, O, N};
+   std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1KV, O, N};
    std::vector<ck::index_t> b1_gs_os_ns_strides =
        input_permute
-           ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // B1 layout [G0, N, G2, O]
+           ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // B1 layout [G0, N, G1KV, O]
-           : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // B1 layout [G0, G2, N, O]
+           : std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // B1 layout [G0, G1KV, N, O]
-   std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
+   std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1Q, M, O};
    std::vector<ck::index_t> c_gs_ms_os_strides =
        output_permute
-           ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
+           ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // C layout [G0, M, G1Q, O]
-           : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
+           : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // C layout [G0, G1Q, M, O]
-   std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
+   std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
    std::vector<ck::index_t> z_gs_ms_ns_strides =
        input_permute
-           ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
+           ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
-           : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
+           : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
-   std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
+   std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
    std::vector<ck::index_t> lse_gs_ms_strides =
-       std::vector<ck::index_t>{G1 * M, M, 1}; // LSE layout [G0, G1, M]
+       std::vector<ck::index_t>{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
    Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
    Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
...
@@ -213,7 +213,7 @@ int run(int argc, char* argv[])
...
@@ -213,7 +213,7 @@ int run(int argc, char* argv[])
return
0
;
return
0
;
}
}
ck
::
index_t
BatchCount
=
G0
*
G1
;
ck
::
index_t
BatchCount
=
G0
*
G1
Q
;
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
...
@@ -278,32 +278,32 @@ int run(int argc, char* argv[])
...
@@ -278,32 +278,32 @@ int run(int argc, char* argv[])
Tensor
<
B1DataType
>
b1_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
B1DataType
>
b1_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
AccDataType
>
acc0_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after gemm0
Tensor
<
AccDataType
>
acc0_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after gemm0
Tensor
<
ADataType
>
a1_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after softmax
Tensor
<
ADataType
>
a1_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after softmax
Tensor
<
ADataType
>
a1_g_m_n_drop
({
G0
*
G1
,
M
,
N
});
Tensor
<
ADataType
>
a1_g_m_n_drop
({
BatchCount
,
M
,
N
});
Tensor
<
LSEDataType
>
lse_g_m_host_result
(
Tensor
<
LSEDataType
>
lse_g_m_host_result
(
{
BatchCount
,
M
});
// scratch object after max + ln(sum)
{
BatchCount
,
M
});
// scratch object after max + ln(sum)
Tensor
<
ZDataType
>
z_g_m_n
({
G0
*
G1
,
M
,
N
});
Tensor
<
ZDataType
>
z_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
CDataType
>
c_g_m_o_host_result
({
BatchCount
,
M
,
O
});
// scratch object after gemm1
Tensor
<
CDataType
>
c_g_m_o_host_result
({
BatchCount
,
M
,
O
});
// scratch object after gemm1
// permute
// permute
a_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
a_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
a_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
a_g_m_k
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
});
b0_g_k_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b0_g_k_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1
;
const
size_t
&
g0
=
idx
[
0
]
/
G1
Q
;
const
size_t
&
g1
=
idx
[
0
]
%
G1
;
const
size_t
&
g1
q
=
idx
[
0
]
%
G1
Q
;
const
size_t
&
g
2
=
g1
/
(
G1
/
G
2
);
const
size_t
&
g
1kv
=
g1
q
/
(
G1
Q
/
G
1KV
);
self
(
idx
)
=
b0_gs_ns_ks
(
g0
,
g
2
,
idx
[
2
],
idx
[
1
]);
self
(
idx
)
=
b0_gs_ns_ks
(
g0
,
g
1kv
,
idx
[
2
],
idx
[
1
]);
});
});
b1_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b1_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1
;
const
size_t
&
g0
=
idx
[
0
]
/
G1
Q
;
const
size_t
&
g1
=
idx
[
0
]
%
G1
;
const
size_t
&
g1
q
=
idx
[
0
]
%
G1
Q
;
const
size_t
&
g
2
=
g1
/
(
G1
/
G
2
);
const
size_t
&
g
1kv
=
g1
q
/
(
G1
Q
/
G
1KV
);
self
(
idx
)
=
b1_gs_os_ns
(
g0
,
g
2
,
idx
[
2
],
idx
[
1
]);
self
(
idx
)
=
b1_gs_os_ns
(
g0
,
g
1kv
,
idx
[
2
],
idx
[
1
]);
});
});
z_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
z_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
z_g_m_n
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
});
// gemm 0
// gemm 0
...
@@ -350,18 +350,18 @@ int run(int argc, char* argv[])
...
@@ -350,18 +350,18 @@ int run(int argc, char* argv[])
// permute
// permute
c_gs_ms_os_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
c_gs_ms_os_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
c_g_m_o_host_result
(
g
,
idx
[
2
],
idx
[
3
]);
self
(
idx
)
=
c_g_m_o_host_result
(
g
,
idx
[
2
],
idx
[
3
]);
});
});
lse_gs_ms_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
lse_gs_ms_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
lse_g_m_host_result
(
g
,
idx
[
2
]);
self
(
idx
)
=
lse_g_m_host_result
(
g
,
idx
[
2
]);
});
});
...
...
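Note on the hunks above: the host-side permute loops recover (g0, g1q) from the flattened batch index and then fold each query head onto its key/value head with g1kv = g1q / (G1Q / G1KV), which is the grouped-query sharing scheme this rename is about. Below is a minimal standalone sketch of that mapping; the helper name and the toy sizes are illustrative only, not part of the example code.

#include <cassert>
#include <cstdio>

// Illustrative helper: for a flattened batch index g in [0, G0 * G1Q),
// recover (g0, g1q) and the K/V head g1kv that query head g1q reads from.
// Assumes G1Q is a multiple of G1KV, as the examples above require.
struct HeadMapping
{
    int g0;   // outer batch index
    int g1q;  // query head index
    int g1kv; // key/value head index shared by G1Q / G1KV query heads
};

inline HeadMapping map_flattened_batch(int g, int G1Q, int G1KV)
{
    assert(G1Q % G1KV == 0);
    const int h_ratio = G1Q / G1KV; // same role as h_ratio in the grouped examples
    HeadMapping m;
    m.g0   = g / G1Q;
    m.g1q  = g % G1Q;
    m.g1kv = m.g1q / h_ratio;
    return m;
}

int main()
{
    // G1Q = 6 query heads sharing G1KV = 2 K/V heads -> 3 query heads per K/V head.
    for(int g = 0; g < 2 * 6; ++g)
    {
        const HeadMapping m = map_flattened_batch(g, 6, 2);
        std::printf("g=%2d -> g0=%d g1q=%d g1kv=%d\n", g, m.g0, m.g1q, m.g1kv);
    }
    return 0;
}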
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc  View file @ 10836d41
...
@@ -11,7 +11,7 @@ int run(int argc, char* argv[])
    bool output_permute = true;
    float p_drop = 0.2;
-   int h_ratio  = 1; // G1 / G2
+   int h_ratio  = 1; // G1Q / G1KV
    const unsigned long long seed   = 1;
    const unsigned long long offset = 0;
...
@@ -64,7 +64,7 @@ int run(int argc, char* argv[])
    std::vector<void*> p_z;         // for result verification
    std::vector<void*> p_z_nullptr; // for time test
    std::vector<void*> p_lse;
-   std::vector<std::vector<int>> g0_g1_m_n_k_o;
+   std::vector<std::vector<int>> g0_g1q_m_n_k_o;
    std::vector<Tensor<ADataType>> a_tensors;
    std::vector<Tensor<B0DataType>> b0_tensors;
...
@@ -87,49 +87,51 @@ int run(int argc, char* argv[])
    for(std::size_t i = 0; i < group_count; i++)
    {
        int M  = 128 * (rand() % 8) + (rand() % 128);
        int N  = 128 * (rand() % 8) + (rand() % 128);
        int K  = DIM;
        int O  = DIM;
        int G0 = rand() % 3 + 1;
-       int G2 = rand() % 5 + 1;
-       int G1 = G2 * h_ratio;
+       int G1KV = rand() % 5 + 1;
+       int G1Q  = G1KV * h_ratio;

-       g0_g1_m_n_k_o.push_back({G0, G1, M, N, K, O});
+       g0_g1q_m_n_k_o.push_back({G0, G1Q, M, N, K, O});

-       std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
+       std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1Q, M, K};
        std::vector<ck::index_t> a_gs_ms_ks_strides =
            input_permute
-               ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1}    // A layout [G0, M, G1, K]
-               : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1};    // A layout [G0, G1, M, K]
+               ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1}  // A layout [G0, M, G1Q, K]
+               : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1};   // A layout [G0, G1Q, M, K]

-       std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G2, N, K};
+       std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1KV, N, K};
        std::vector<ck::index_t> b0_gs_ns_ks_strides =
            input_permute
-               ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1}        // B0 layout [G0, N, G2, K]
-               : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1};        // B0 layout [G0, G2, N, K]
+               ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1}    // B0 layout [G0, N, G1KV, K]
+               : std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1};      // B0 layout [G0, G1KV, N, K]

-       std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G2, O, N};
+       std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1KV, O, N};
        std::vector<ck::index_t> b1_gs_os_ns_strides =
            input_permute
-               ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O}        // B1 layout [G0, N, G2, O]
-               : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O};        // B1 layout [G0, G2, N, O]
+               ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O}    // B1 layout [G0, N, G1KV, O]
+               : std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O};      // B1 layout [G0, G1KV, N, O]

-       std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
+       std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1Q, M, O};
        std::vector<ck::index_t> c_gs_ms_os_strides =
            output_permute
-               ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1}    // C layout [G0, M, G1, O]
-               : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1};    // C layout [G0, G1, M, O]
+               ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1}  // C layout [G0, M, G1Q, O]
+               : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1};   // C layout [G0, G1Q, M, O]

-       std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
+       std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
        std::vector<ck::index_t> z_gs_ms_ns_strides =
            input_permute
-               ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1}    // Z layout [G0, M, G1, N]
-               : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1};    // Z layout [G0, G1, M, N]
+               ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1}  // Z layout [G0, M, G1Q, N]
+               : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1};   // Z layout [G0, G1Q, M, N]

-       std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
+       std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
        std::vector<ck::index_t> lse_gs_ms_strides =
-           std::vector<ck::index_t>{G1 * M, M, 1};  // LSE layout [G0, G1, M]
+           std::vector<ck::index_t>{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]

        problem_descs.push_back({a_gs_ms_ks_lengths,
                                 a_gs_ms_ks_strides,
...
@@ -156,7 +158,7 @@ int run(int argc, char* argv[])
        Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
        Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);

-       int Batch = G0 * G1;
+       int Batch = G0 * G1Q;
        flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
        num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
                     sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
...
@@ -313,12 +315,12 @@ int run(int argc, char* argv[])
    for(std::size_t i = 0; i < group_count; i++)
    {
-       const int& G0 = g0_g1_m_n_k_o[i][0];
-       const int& G1 = g0_g1_m_n_k_o[i][1];
-       const int& M  = g0_g1_m_n_k_o[i][2];
-       const int& N  = g0_g1_m_n_k_o[i][3];
-       const int& K  = g0_g1_m_n_k_o[i][4];
-       const int& O  = g0_g1_m_n_k_o[i][5];
+       const int& G0  = g0_g1q_m_n_k_o[i][0];
+       const int& G1Q = g0_g1q_m_n_k_o[i][1];
+       const int& M   = g0_g1q_m_n_k_o[i][2];
+       const int& N   = g0_g1q_m_n_k_o[i][3];
+       const int& K   = g0_g1q_m_n_k_o[i][4];
+       const int& O   = g0_g1q_m_n_k_o[i][5];

        const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths;
        const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides;
...
@@ -339,39 +341,39 @@ int run(int argc, char* argv[])
        z_gs_ms_ns_device_buf.FromDevice(z_gs_ms_ns_device_result.mData.data());
        lse_gs_ms_device_buf.FromDevice(lse_gs_ms_device_result.mData.data());

-       Tensor<ADataType> a_g_m_k({G0 * G1, M, K});
-       Tensor<B0DataType> b0_g_k_n({G0 * G1, K, N});
-       Tensor<B1DataType> b1_g_n_o({G0 * G1, N, O});
-       Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N});        // scratch object after gemm0
-       Tensor<ADataType> a1_g_m_n({G0 * G1, M, N});            // scratch object after softmax
-       Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N});       // scratch object after softmax
-       Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1
+       Tensor<ADataType> a_g_m_k({G0 * G1Q, M, K});
+       Tensor<B0DataType> b0_g_k_n({G0 * G1Q, K, N});
+       Tensor<B1DataType> b1_g_n_o({G0 * G1Q, N, O});
+       Tensor<AccDataType> acc0_g_m_n({G0 * G1Q, M, N});        // scratch object after gemm0
+       Tensor<ADataType> a1_g_m_n({G0 * G1Q, M, N});            // scratch object after softmax
+       Tensor<ADataType> a1_g_m_n_drop({G0 * G1Q, M, N});       // scratch object after softmax
+       Tensor<CDataType> c_g_m_o_host_result({G0 * G1Q, M, O}); // scratch object after gemm1
        Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
-       Tensor<ZDataType> z_g_m_n({G0 * G1, M, N});
-       Tensor<LSEDataType> lse_g_m_host_result({G0 * G1, M});  // scratch object after gemm1
+       Tensor<ZDataType> z_g_m_n({G0 * G1Q, M, N});
+       Tensor<LSEDataType> lse_g_m_host_result({G0 * G1Q, M});  // scratch object after gemm1
        Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);

        // permute
        a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
-           a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+           a_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
        });
        b0_g_k_n.ForEach([&](auto& self, auto idx) {
-           const size_t& g0 = idx[0] / G1;
-           const size_t& g1 = idx[0] % G1;
-           const size_t& g2 = g1 / h_ratio;
-           self(idx)        = b0_gs_ns_ks(g0, g2, idx[2], idx[1]);
+           const size_t& g0   = idx[0] / G1Q;
+           const size_t& g1q  = idx[0] % G1Q;
+           const size_t& g1kv = g1q / h_ratio;
+           self(idx)          = b0_gs_ns_ks(g0, g1kv, idx[2], idx[1]);
        });
        b1_g_n_o.ForEach([&](auto& self, auto idx) {
-           const size_t& g0 = idx[0] / G1;
-           const size_t& g1 = idx[0] % G1;
-           const size_t& g2 = g1 / h_ratio;
-           self(idx)        = b1_gs_os_ns(g0, g2, idx[2], idx[1]);
+           const size_t& g0   = idx[0] / G1Q;
+           const size_t& g1q  = idx[0] % G1Q;
+           const size_t& g1kv = g1q / h_ratio;
+           self(idx)          = b1_gs_os_ns(g0, g1kv, idx[2], idx[1]);
        });
        z_gs_ms_ns_device_result.ForEach([&](auto& self, auto idx) {
-           z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+           z_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
        });

        // gemm 0
...
@@ -421,18 +423,18 @@ int run(int argc, char* argv[])
        // permute
        c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
-           const size_t g   = g0 * G1 + g1;
+           const size_t& g1q = idx[1];
+           const size_t g    = g0 * G1Q + g1q;
            self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
        });
        lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
-           const size_t g   = g0 * G1 + g1;
+           const size_t& g1q = idx[1];
+           const size_t g    = g0 * G1Q + g1q;
            self(idx) = lse_g_m_host_result(g, idx[2]);
        });
...
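The stride vectors in the hunks above are what distinguish the head-major [G0, G1Q, M, K] storage from the permuted, token-major [G0, M, G1Q, K] storage selected by input_permute/output_permute. The short sketch below shows how either stride set turns the same logical 4-D index into a flat offset; the helper and the toy sizes are illustrative, not composable_kernel API.

#include <array>
#include <cstdio>

// Illustrative: flat offset of element (g0, g1q, m, k) given per-dimension strides.
// The stride sets below mirror the two branches used above for the A/Q tensors.
inline long offset_4d(std::array<long, 4> idx, std::array<long, 4> strides)
{
    return idx[0] * strides[0] + idx[1] * strides[1] + idx[2] * strides[2] + idx[3] * strides[3];
}

int main()
{
    const long G1Q = 6, M = 4, K = 8; // small toy sizes

    // Layout [G0, G1Q, M, K]: K is innermost, then M, then G1Q, then G0.
    const std::array<long, 4> head_major{G1Q * M * K, M * K, K, 1};

    // Permuted layout [G0, M, G1Q, K]: G1Q and M swap places in memory, so the
    // stride of g1q becomes K and the stride of m becomes G1Q * K.
    const std::array<long, 4> token_major{M * G1Q * K, K, G1Q * K, 1};

    // The same logical element (g0=1, g1q=2, m=3, k=5) lands at different flat offsets.
    const std::array<long, 4> idx{1, 2, 3, 5};
    std::printf("head-major offset:  %ld\n", offset_4d(idx, head_major));
    std::printf("token-major offset: %ld\n", offset_4d(idx, token_major));
    return 0;
}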
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp  View file @ 10836d41
...
@@ -273,15 +273,15 @@ int run(int argc, char* argv[])
    // Overall QKV matrices shape
    // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
-   // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
-   // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
+   // y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
+   // y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
    ck::index_t M  = 512;
    ck::index_t N  = 512;
    ck::index_t K  = DIM;
    ck::index_t O  = DIM;
    ck::index_t G0 = 4;
-   ck::index_t G1 = 6; // h_q
-   ck::index_t G2 = 6; // h_kv
+   ck::index_t G1Q  = 6; // h_q
+   ck::index_t G1KV = 6; // h_kv

    bool input_permute  = false;
    bool output_permute = false;
...
@@ -306,13 +306,13 @@ int run(int argc, char* argv[])
        init_method = std::stoi(argv[2]);
        time_kernel = std::stoi(argv[3]);
        M  = std::stoi(argv[4]);
        N  = std::stoi(argv[5]);
        K  = std::stoi(argv[6]);
        O  = std::stoi(argv[7]);
        G0 = std::stoi(argv[8]);
-       G1 = std::stoi(argv[9]);
-       G2 = std::stoi(argv[10]);
+       G1Q  = std::stoi(argv[9]);
+       G1KV = std::stoi(argv[10]);
        p_drop = std::stof(argv[11]);
...
@@ -324,7 +324,7 @@ int run(int argc, char* argv[])
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
-       printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
+       printf("arg4 to 10: M, N, K, O, G0, G1Q, G1KV\n");
        printf("arg11: p_drop\n");
        printf("arg12 to 13: input / output permute\n");
        exit(0);
...
@@ -343,8 +343,8 @@ int run(int argc, char* argv[])
    std::cout << "K: " << K << std::endl;
    std::cout << "O: " << O << std::endl;
    std::cout << "G0: " << G0 << std::endl;
-   std::cout << "G1: " << G1 << std::endl;
-   std::cout << "G2: " << G2 << std::endl;
+   std::cout << "G1Q: " << G1Q << std::endl;
+   std::cout << "G1KV: " << G1KV << std::endl;
    std::cout << "alpha: " << alpha << std::endl;
    std::cout << "input_permute: " << input_permute << std::endl;
    std::cout << "output_permute: " << output_permute << std::endl;
...
@@ -352,63 +352,63 @@ int run(int argc, char* argv[])
    std::cout << "seed: " << seed << std::endl;
    std::cout << "offset: " << offset << std::endl;

-   const ck::index_t BatchCount = G0 * G1;
+   const ck::index_t BatchCount = G0 * G1Q;

-   std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
+   std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1Q, M, K};
    std::vector<ck::index_t> q_gs_ms_ks_strides =
        input_permute
-           ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1}    // Q layout [G0, M, G1, K]
-           : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1};    // Q layout [G0, G1, M, K]
+           ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1}  // Q layout [G0, M, G1Q, K]
+           : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1};   // Q layout [G0, G1Q, M, K]

-   std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
+   std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1KV, N, K};
    std::vector<ck::index_t> k_gs_ns_ks_strides =
        input_permute
-           ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1}        // K layout [G0, N, G2, K]
-           : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1};        // K layout [G0, G2, N, K]
+           ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1}    // K layout [G0, N, G1KV, K]
+           : std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1};      // K layout [G0, G1KV, N, K]

-   std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
+   std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1KV, O, N};
    std::vector<ck::index_t> v_gs_os_ns_strides =
        input_permute
-           ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O}        // V layout [G0, N, G2, O]
-           : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O};        // V layout [G0, G2, N, O]
+           ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O}    // V layout [G0, N, G1KV, O]
+           : std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O};      // V layout [G0, G1KV, N, O]

-   std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
+   std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1Q, M, O};
    std::vector<ck::index_t> y_gs_ms_os_strides =
        output_permute
-           ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1}    // Y layout [G0, M, G1, O]
-           : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1};    // Y layout [G0, G1, M, O]
+           ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1}  // Y layout [G0, M, G1Q, O]
+           : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1};   // Y layout [G0, G1Q, M, O]

-   std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
+   std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1Q, M, N};
    std::vector<ck::index_t> d0_gs_ms_ns_strides =
        input_permute
-           ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1}    // D layout [G0, M, G1, N]
-           : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1};    // D layout [G0, G1, M, N]
+           ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1}  // D layout [G0, M, G1Q, N]
+           : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1};   // D layout [G0, G1Q, M, N]

-   std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
+   std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
    std::vector<ck::index_t> z_gs_ms_ns_strides =
        input_permute
-           ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1}    // Z layout [G0, M, G1, N]
-           : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1};    // Z layout [G0, G1, M, N]
+           ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1}  // Z layout [G0, M, G1Q, N]
+           : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1};   // Z layout [G0, G1Q, M, N]

-   std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
+   std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1Q, N, K};
    std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
        input_permute
-           ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1}    // KGrad layout [G0, N, G1, K]
-           : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1};    // KGrad layout [G0, G1, N, K]
+           ? std::vector<ck::index_t>{N * G1Q * K, K, G1Q * K, 1}  // KGrad layout [G0, N, G1Q, K]
+           : std::vector<ck::index_t>{G1Q * N * K, N * K, K, 1};   // KGrad layout [G0, G1Q, N, K]

-   std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
+   std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1Q, O, N};
    std::vector<ck::index_t> vgrad_gs_os_ns_strides =
        input_permute
-           ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O}    // VGrad layout [G0, N, G1, O]
-           : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O};    // VGrad layout [G0, G1, N, O]
+           ? std::vector<ck::index_t>{N * G1Q * O, O, 1, G1Q * O}  // VGrad layout [G0, N, G1Q, O]
+           : std::vector<ck::index_t>{G1Q * N * O, N * O, 1, O};   // VGrad layout [G0, G1Q, N, O]

    // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
    // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
    //    = exp(Si) / exp(log(sum(exp() + ...)))
    //    = exp(Si - log(sum(exp() + ...)))
    //               ^^^^^^^^^^^^^^^^^^^^^
    //                        LSE
-   std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
-   std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1};  // LSE layout [G0, G1, M]
+   std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
+   std::vector<ck::index_t> lse_gs_ms_strides{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]

    Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
    Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
...
@@ -467,7 +467,7 @@ int run(int argc, char* argv[])
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-       ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
+       ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1q, m, o]
        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
        // dO dot O = [0; 1; 2; ...]
        break;
...
@@ -475,7 +475,7 @@ int run(int argc, char* argv[])
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-       ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
+       ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1q, m, o]
        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
        // assume mnko = 256
        // P = softmax(QK) = 0.0039 * ones
...
@@ -489,7 +489,8 @@ int run(int argc, char* argv[])
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-       ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0,g1, m, o]
+       ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1q, m, o]
        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
        // assume mnko = 256
        // P = softmax(QK) = 0.0039 * ones
...
@@ -651,7 +652,7 @@ int run(int argc, char* argv[])
        Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
        Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
-       Tensor<Acc0BiasDataType> d0_g_m_n({G0 * G1, M, N});
+       Tensor<Acc0BiasDataType> d0_g_m_n({BatchCount, M, N});
        Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
        Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
        Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
...
@@ -662,27 +663,27 @@ int run(int argc, char* argv[])
        z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
        z_gs_ms_ns.ForEach([&](auto& self, auto idx) {
-           z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+           z_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
        });
        q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
-           q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+           q_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
        });
        k_g_n_k.ForEach([&](auto& self, auto idx) {
-           const size_t& g0 = idx[0] / G1;
-           const size_t& g1 = idx[0] % G1;
-           const size_t& g2 = g1 / (G1 / G2);
-           self(idx)        = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
+           const size_t& g0   = idx[0] / G1Q;
+           const size_t& g1q  = idx[0] % G1Q;
+           const size_t& g1kv = g1q / (G1Q / G1KV);
+           self(idx)          = k_gs_ns_ks(g0, g1kv, idx[1], idx[2]);
        });
        v_g_n_o.ForEach([&](auto& self, auto idx) {
-           const size_t& g0 = idx[0] / G1;
-           const size_t& g1 = idx[0] % G1;
-           const size_t& g2 = g1 / (G1 / G2);
-           self(idx)        = v_gs_os_ns(g0, g2, idx[2], idx[1]);
+           const size_t& g0   = idx[0] / G1Q;
+           const size_t& g1q  = idx[0] % G1Q;
+           const size_t& g1kv = g1q / (G1Q / G1KV);
+           self(idx)          = v_gs_os_ns(g0, g1kv, idx[2], idx[1]);
        });
        d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
-           d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+           d0_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
        });

        // run fwd again for y, cause z_g_m_n update
        run_attention_fwd_host(q_g_m_k,
...
@@ -699,10 +700,10 @@ int run(int argc, char* argv[])
                               p_dropout_in_uint8_t,
                               rp_dropout);

        y_gs_ms_os.ForEach([&](auto& self, auto idx) {
-           self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
+           self(idx) = y_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]);
        });
        lse_gs_ms.ForEach(
-           [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
+           [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1Q + idx[1], idx[2]); });

        y_device_buf.ToDevice(y_gs_ms_os.mData.data());
        lse_device_buf.ToDevice(lse_gs_ms.mData.data());
...
@@ -720,7 +721,7 @@ int run(int argc, char* argv[])
        Tensor<InputDataType> ygrad_dot_y_g_m({BatchCount, M});
        ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
-           ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+           ygrad_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
        });
#if PRINT_HOST
...
@@ -844,35 +845,35 @@ int run(int argc, char* argv[])
        // permute
        qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
-           const size_t g   = g0 * G1 + g1;
+           const size_t& g1q = idx[1];
+           const size_t g    = g0 * G1Q + g1q;
            self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
        });
        kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
-           const size_t g   = g0 * G1 + g1;
+           const size_t& g1q = idx[1];
+           const size_t g    = g0 * G1Q + g1q;
            self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
        });
        vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
-           const size_t g   = g0 * G1 + g1;
+           const size_t& g1q = idx[1];
+           const size_t g    = g0 * G1Q + g1q;
            self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
        });
        d0grad_gs_ms_ns_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
-           const size_t g   = g0 * G1 + g1;
+           const size_t& g1q = idx[1];
+           const size_t g    = g0 * G1Q + g1q;
            self(idx) = sgrad_g_m_n(g, idx[2], idx[3]);
        });
...
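The LSE comment in the hunk above is the identity the backward examples rely on: the forward pass stores only log(sum_j exp(S_j)) per softmax row, and the backward pass recovers P_i = exp(S_i - LSE) instead of re-running a full softmax. A small numerical sketch of that recovery, written against plain standard C++ rather than the example's tensor types:

#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative: compute the row log-sum-exp in a numerically stable way,
// then recover softmax probabilities as exp(s - lse), matching
// Pi = exp(Si - log(sum_j exp(Sj))) from the comment above.
inline float row_lse(const std::vector<float>& s)
{
    float m = s[0];
    for(float v : s)
        if(v > m)
            m = v;
    float sum = 0.f;
    for(float v : s)
        sum += std::exp(v - m); // subtract the row max to avoid overflow
    return m + std::log(sum);
}

int main()
{
    const std::vector<float> s{0.5f, -1.0f, 2.0f, 0.0f}; // one row of scaled Q*K^T scores
    const float lse = row_lse(s);

    float total = 0.f;
    for(float v : s)
    {
        const float p = std::exp(v - lse); // softmax probability recovered from LSE
        total += p;
        std::printf("p = %f\n", p);
    }
    std::printf("sum = %f (should be 1)\n", total);
    return 0;
}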
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp
View file @
10836d41
...
@@ -271,11 +271,11 @@ int run(int argc, char* argv[])
...
@@ -271,11 +271,11 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_g1
q
_m_o = reshape(y_g_m_o, [G0, G1
Q
, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
// y_g0_m_g1
q
_o = permute(y_g0_g1
q
_m_o, [0, 2, 1, 3])
float
alpha
=
1.
f
/
std
::
sqrt
(
DIM
);
float
alpha
=
1.
f
/
std
::
sqrt
(
DIM
);
float
p_drop
=
0.0
;
float
p_drop
=
0.0
;
int
h_ratio
=
1
;
// G1 / G
2
int
h_ratio
=
1
;
// G1
Q
/ G
1KV
bool
input_permute
=
true
;
bool
input_permute
=
true
;
bool
output_permute
=
true
;
bool
output_permute
=
true
;
...
@@ -379,67 +379,71 @@ int run(int argc, char* argv[])
...
@@ -379,67 +379,71 @@ int run(int argc, char* argv[])
std
::
size_t
flop
=
0
,
num_byte
=
0
;
std
::
size_t
flop
=
0
,
num_byte
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
{
int
M
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
M
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
N
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
N
=
128
*
(
rand
()
%
8
)
+
(
rand
()
%
128
);
int
K
=
DIM
;
int
K
=
DIM
;
int
O
=
DIM
;
int
O
=
DIM
;
int
G0
=
rand
()
%
4
+
1
;
int
G0
=
rand
()
%
4
+
1
;
int
G
2
=
rand
()
%
4
+
1
;
int
G
1KV
=
rand
()
%
4
+
1
;
int
G1
=
G
2
*
h_ratio
;
int
G1
Q
=
G
1KV
*
h_ratio
;
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
Q
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_strides
=
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_strides
=
input_permute
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// Q layout [G0, M, G1, K]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
K
,
K
,
G1
Q
*
K
,
1
}
// Q layout [G0, M, G1
Q
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1, M, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1
Q
, M, K]
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G
2
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G
1KV
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_strides
=
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_strides
=
input_permute
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G2
*
K
,
K
,
G2
*
K
,
1
}
// K layout [G0, N, G2, K]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1KV
*
K
,
K
,
G1KV
*
K
,
1
}
:
std
::
vector
<
ck
::
index_t
>
{
G2
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G2, N, K]
// K layout [G0, N, G1KV, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1KV
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G1KV, N, K]
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G
2
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G
1KV
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_strides
=
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_strides
=
input_permute
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G2
*
O
,
O
,
1
,
G2
*
O
}
// V layout [G0, N, G2, O]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1KV
*
O
,
O
,
1
,
G1KV
*
O
}
:
std
::
vector
<
ck
::
index_t
>
{
G2
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G2, N, O]
// V layout [G0, N, G1KV, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1KV
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G1KV, N, O]
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
Q
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_strides
=
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_strides
=
output_permute
output_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// Y layout [G0, M, G1, O]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
O
,
O
,
G1
Q
*
O
,
1
}
// Y layout [G0, M, G1
Q
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1, M, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1
Q
, M, O]
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_lengths
{
G0
,
G1
Q
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_strides
=
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_strides
=
input_permute
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// d0 layout [G0, M, G1, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
N
,
N
,
G1
Q
*
N
,
1
}
// d0 layout [G0, M, G1
Q
, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// d0 layout [G0, G1, M, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
N
,
M
*
N
,
N
,
1
};
// d0 layout [G0, G1
Q
, M, N]
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_lengths
{
G0
,
G1
Q
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_strides
=
std
::
vector
<
ck
::
index_t
>
z_gs_ms_ns_strides
=
input_permute
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
N
,
N
,
G1
Q
*
N
,
1
}
// Z layout [G0, M, G1
Q
, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1
Q
, M, N]
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_lengths
{
G0
,
G1
Q
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_strides
=
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_strides
=
input_permute
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1Q
*
K
,
K
,
G1Q
*
K
,
1
}
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// KGrad layout [G0, N, G1, K]
// KGrad layout [G0, N, G1Q, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// KGrad layout [G0, G1, N, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
N
*
K
,
N
*
K
,
K
,
1
};
// KGrad layout [G0, G1Q, N, K]
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_lengths
{
G0
,
G1
Q
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_strides
=
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_strides
=
input_permute
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1Q
*
O
,
O
,
1
,
G1Q
*
O
}
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// VGrad layout [G0, N, G1, O]
// VGrad layout [G0, N, G1Q, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// VGrad layout [G0, G1, N, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1Q
*
N
*
O
,
N
*
O
,
1
,
O
};
// VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
// pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^
// ^^^^^^^^^^^^^^^^^^^^^
// LSE
// LSE
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_lengths
{
G0
,
G1
Q
,
M
};
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
*
M
,
M
,
1
};
// LSE layout [G0, G1, M]
std
::
vector
<
ck
::
index_t
>
lse_gs_ms_strides
{
G1
Q
*
M
,
M
,
1
};
// LSE layout [G0, G1
Q
, M]
problem_descs
.
push_back
({
problem_descs
.
push_back
({
q_gs_ms_ks_lengths
,
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
,
q_gs_ms_ks_strides
,
...
@@ -463,7 +467,7 @@ int run(int argc, char* argv[])
...
@@ -463,7 +467,7 @@ int run(int argc, char* argv[])
{},
// acc1_bias_gs_ms_os_strides,
{},
// acc1_bias_gs_ms_os_strides,
});
});
int
BatchCount
=
G0
*
G1
;
int
BatchCount
=
G0
*
G1
Q
;
flop
+=
(
size_t
(
3
)
*
M
*
N
*
K
+
size_t
(
2
)
*
M
*
N
*
O
)
*
2
*
BatchCount
;
flop
+=
(
size_t
(
3
)
*
M
*
N
*
K
+
size_t
(
2
)
*
M
*
N
*
O
)
*
2
*
BatchCount
;
// Q/K/V/Y, dQ/dK/dV/dY, LSE
// Q/K/V/Y, dQ/dK/dV/dY, LSE
num_byte
+=
num_byte
+=
...
@@ -532,7 +536,8 @@ int run(int argc, char* argv[])
...
@@ -532,7 +536,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
// dy[g0, g1q, m, o]
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
// dO dot O = [0; 1; 2; ...]
// dO dot O = [0; 1; 2; ...]
break
;
break
;
...
@@ -540,7 +545,8 @@ int run(int argc, char* argv[])
...
@@ -540,7 +545,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
3
>
{});
// dy[g0, g1q, m, o]
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
// assume mnko = 256
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// P = softmax(QK) = 0.0039 * ones
...
@@ -555,7 +561,7 @@ int run(int argc, char* argv[])
...
@@ -555,7 +561,7 @@ int run(int argc, char* argv[])
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1, m, o]
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0, g1
q
, m, o]
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
d0_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
// assume mnko = 256
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// P = softmax(QK) = 0.0039 * ones
...
@@ -578,24 +584,24 @@ int run(int argc, char* argv[])
...
@@ -578,24 +584,24 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
p_drop_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
InputDataType
>
p_drop_g_m_n
({
BatchCount
,
M
,
N
});
q_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
q_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
q_g_m_k
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
});
k_g_n_k
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1
;
const
size_t
&
g0
=
idx
[
0
]
/
G1
Q
;
const
size_t
&
g1
=
idx
[
0
]
%
G1
;
const
size_t
&
g1
q
=
idx
[
0
]
%
G1
Q
;
const
size_t
&
g
2
=
g1
/
h_ratio
;
const
size_t
&
g
1kv
=
g1
q
/
h_ratio
;
self
(
idx
)
=
k_gs_ns_ks
(
g0
,
g
2
,
idx
[
1
],
idx
[
2
]);
self
(
idx
)
=
k_gs_ns_ks
(
g0
,
g
1kv
,
idx
[
1
],
idx
[
2
]);
});
});
d0_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
d0_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
d0_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
d0_g_m_n
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
});
v_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1
;
const
size_t
&
g0
=
idx
[
0
]
/
G1
Q
;
const
size_t
&
g1
=
idx
[
0
]
%
G1
;
const
size_t
&
g1
q
=
idx
[
0
]
%
G1
Q
;
const
size_t
&
g
2
=
g1
/
h_ratio
;
const
size_t
&
g
1kv
=
g1
q
/
h_ratio
;
self
(
idx
)
=
v_gs_os_ns
(
g0
,
g
2
,
idx
[
2
],
idx
[
1
]);
self
(
idx
)
=
v_gs_os_ns
(
g0
,
g
1kv
,
idx
[
2
],
idx
[
1
]);
});
});
q_g_m_ks
.
push_back
(
q_g_m_k
);
q_g_m_ks
.
push_back
(
q_g_m_k
);
...
@@ -745,11 +751,11 @@ int run(int argc, char* argv[])
...
@@ -745,11 +751,11 @@ int run(int argc, char* argv[])
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
{
int
G1
=
q_tensors
[
i
].
GetLengths
()[
1
];
int
G1
Q
=
q_tensors
[
i
].
GetLengths
()[
1
];
// copy z matirx data form device
// copy z matirx data form device
z_tensors_device
[
i
]
->
FromDevice
(
z_tensors
[
i
].
mData
.
data
());
z_tensors_device
[
i
]
->
FromDevice
(
z_tensors
[
i
].
mData
.
data
());
z_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
z_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_ns
[
i
](
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
z_g_m_ns
[
i
](
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
});
run_attention_fwd_host
(
q_g_m_ks
[
i
],
run_attention_fwd_host
(
q_g_m_ks
[
i
],
k_g_n_ks
[
i
],
k_g_n_ks
[
i
],
...
@@ -766,11 +772,11 @@ int run(int argc, char* argv[])
...
@@ -766,11 +772,11 @@ int run(int argc, char* argv[])
rp_dropout
);
rp_dropout
);
y_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
y_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_os
[
i
](
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
self
(
idx
)
=
y_g_m_os
[
i
](
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
});
y_tensors_device
[
i
]
->
ToDevice
(
y_tensors
[
i
].
data
());
y_tensors_device
[
i
]
->
ToDevice
(
y_tensors
[
i
].
data
());
lse_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
lse_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_ms
[
i
](
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
]);
self
(
idx
)
=
lse_g_ms
[
i
](
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
]);
});
});
lse_tensors_device
[
i
]
->
ToDevice
(
lse_tensors
[
i
].
data
());
lse_tensors_device
[
i
]
->
ToDevice
(
lse_tensors
[
i
].
data
());
qgrad_tensors_device
[
i
]
->
SetZero
();
qgrad_tensors_device
[
i
]
->
SetZero
();
...
@@ -785,12 +791,12 @@ int run(int argc, char* argv[])
...
@@ -785,12 +791,12 @@ int run(int argc, char* argv[])
{
{
int
G0
=
q_tensors
[
i
].
GetLengths
()[
0
];
int
G0
=
q_tensors
[
i
].
GetLengths
()[
0
];
int
G1
=
q_tensors
[
i
].
GetLengths
()[
1
];
int
G1
Q
=
q_tensors
[
i
].
GetLengths
()[
1
];
int
O
=
v_tensors
[
i
].
GetLengths
()[
2
];
int
O
=
v_tensors
[
i
].
GetLengths
()[
2
];
int
N
=
v_tensors
[
i
].
GetLengths
()[
3
];
int
N
=
v_tensors
[
i
].
GetLengths
()[
3
];
int
M
=
q_tensors
[
i
].
GetLengths
()[
2
];
int
M
=
q_tensors
[
i
].
GetLengths
()[
2
];
int
K
=
q_tensors
[
i
].
GetLengths
()[
3
];
int
K
=
q_tensors
[
i
].
GetLengths
()[
3
];
int
BatchCount
=
G0
*
G1
;
int
BatchCount
=
G0
*
G1
Q
;
Tensor
<
OutputDataType
>
qgrad_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
OutputDataType
>
qgrad_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
OutputDataType
>
kgrad_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
OutputDataType
>
kgrad_g_n_k
({
BatchCount
,
N
,
K
});
Tensor
<
OutputDataType
>
vgrad_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
OutputDataType
>
vgrad_g_n_o
({
BatchCount
,
N
,
O
});
...
@@ -800,7 +806,7 @@ int run(int argc, char* argv[])
...
@@ -800,7 +806,7 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
ygrad_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
InputDataType
>
ygrad_g_m_o
({
BatchCount
,
M
,
O
});
ygrad_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
ygrad_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
ygrad_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
ygrad_g_m_o
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
});
auto
ref_gemm0_grad
=
ReferenceGemm0GradInstance
{};
auto
ref_gemm0_grad
=
ReferenceGemm0GradInstance
{};
auto
ref_gemm0_grad_invoker
=
ref_gemm0_grad
.
MakeInvoker
();
auto
ref_gemm0_grad_invoker
=
ref_gemm0_grad
.
MakeInvoker
();
...
@@ -868,34 +874,34 @@ int run(int argc, char* argv[])
        vgrad_tensors_device[i]->FromDevice(vgrad_gs_os_ns_device_result.data());
        // permute
        qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
-           const size_t g   = g0 * G1 + g1;
+           const size_t& g1q = idx[1];
+           const size_t g    = g0 * G1Q + g1q;
            self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
        });
        kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
-           const size_t g   = g0 * G1 + g1;
+           const size_t& g1q = idx[1];
+           const size_t g    = g0 * G1Q + g1q;
            self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
        });
        d0grad_gs_ms_ns_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
-           const size_t g   = g0 * G1 + g1;
+           const size_t& g1q = idx[1];
+           const size_t g    = g0 * G1Q + g1q;
            self(idx) = sgrad_g_m_n(g, idx[2], idx[3]);
        });
        vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
-           const size_t g   = g0 * G1 + g1;
+           const size_t& g1q = idx[1];
+           const size_t g    = g0 * G1Q + g1q;
            self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
        });
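Note: the gradient permute loops above all rely on the same flattened-batch convention, so a short standalone sketch may help. It assumes only what the loops themselves show: a (g0, g1q) pair is flattened to g = g0 * G1Q + g1q and recovered with a division and a modulo. The code below does not use composable_kernel's Tensor class; the sizes are arbitrary.

#include <cassert>
#include <cstdio>

int main()
{
    const int G0 = 4, G1Q = 6; // example sizes only

    for(int g0 = 0; g0 < G0; ++g0)
    {
        for(int g1q = 0; g1q < G1Q; ++g1q)
        {
            const int g = g0 * G1Q + g1q; // flatten (g0, g1q) into one batch index
            assert(g / G1Q == g0);        // recover g0
            assert(g % G1Q == g1q);       // recover g1q
        }
    }
    std::printf("flattened batch count = %d\n", G0 * G1Q);
    return 0;
}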
...
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward_v2.inc View file @ 10836d41
...
@@ -14,12 +14,12 @@ int run(int argc, char* argv[])
    ck::index_t K = DIM;
    ck::index_t O = DIM;
-   // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
-   // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
-   // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
+   // Output shape C[G0, M, G1Q, O]. Batch dim, outer dim, inner dim must match GEMM shape
+   // C_g0_g1q_m_o = reshape(C_g_m_o, [g0, g1q, m, o])
+   // C_g0_m_g1q_o = permute(C_g0_g1q_m_o, [0, 2, 1, 3])
    ck::index_t G0 = 7;
-   ck::index_t G1 = 12; // h_q
-   ck::index_t G2 = 12; // h_kv
+   ck::index_t G1Q  = 12; // h_q
+   ck::index_t G1KV = 12; // h_kv
    bool input_permute  = false;
    bool output_permute = true;
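Note: the comments above describe the output as reshaped from [G, M, O] to [G0, G1Q, M, O] and then permuted to [G0, M, G1Q, O]. A minimal sketch of that permute, assuming contiguous row-major buffers and using plain std::vector instead of CK's Tensor (permute_0213 is a hypothetical helper name):

#include <cstddef>
#include <vector>

// Permute a contiguous [G0, G1Q, M, O] buffer into a contiguous [G0, M, G1Q, O] buffer.
std::vector<float> permute_0213(const std::vector<float>& y, int G0, int G1Q, int M, int O)
{
    std::vector<float> out(y.size());
    for(int g0 = 0; g0 < G0; ++g0)
        for(int g1q = 0; g1q < G1Q; ++g1q)
            for(int m = 0; m < M; ++m)
                for(int o = 0; o < O; ++o)
                {
                    const std::size_t src = ((std::size_t(g0) * G1Q + g1q) * M + m) * O + o;
                    const std::size_t dst = ((std::size_t(g0) * M + m) * G1Q + g1q) * O + o;
                    out[dst] = y[src];
                }
    return out;
}

int main()
{
    const int G0 = 2, G1Q = 3, M = 4, O = 5;
    std::vector<float> y(std::size_t(G0) * G1Q * M * O, 1.0f);
    const auto out = permute_0213(y, G0, G1Q, M, O);
    return out.size() == y.size() ? 0 : 1;
}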
...
@@ -44,13 +44,13 @@ int run(int argc, char* argv[])
        init_method = std::stoi(argv[2]);
        time_kernel = std::stoi(argv[3]);
        M  = std::stoi(argv[4]);
        N  = std::stoi(argv[5]);
        K  = std::stoi(argv[6]);
        O  = std::stoi(argv[7]);
        G0 = std::stoi(argv[8]);
-       G1 = std::stoi(argv[9]);
-       G2 = std::stoi(argv[10]);
+       G1Q  = std::stoi(argv[9]);
+       G1KV = std::stoi(argv[10]);
        p_drop = std::stof(argv[11]);
...
@@ -62,7 +62,7 @@ int run(int argc, char* argv[])
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
-       printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
+       printf("arg4 to 10: M, N, K, O, G0, G1Q, G1KV\n");
        printf("arg11: p_drop\n");
        printf("arg12 to 13: input / output permute\n");
        exit(0);
...
@@ -73,45 +73,45 @@ int run(int argc, char* argv[])
    float rp_dropout = 1.0 / p_dropout;
    float alpha      = 1.f / std::sqrt(K);
-   std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
+   std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1Q, M, K};
    std::vector<ck::index_t> a_gs_ms_ks_strides =
        input_permute
-           ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
-           : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
+           ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // A layout [G0, M, G1Q, K]
+           : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1};  // A layout [G0, G1Q, M, K]
-   std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G2, N, K};
+   std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1KV, N, K};
    std::vector<ck::index_t> b0_gs_ns_ks_strides =
        input_permute
-           ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // B0 layout [G0, N, G2, K]
-           : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // B0 layout [G0, G2, N, K]
+           ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // B0 layout [G0, N, G1KV, K]
+           : std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1};   // B0 layout [G0, G1KV, N, K]
-   std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G2, O, N};
+   std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1KV, O, N};
    std::vector<ck::index_t> b1_gs_os_ns_strides =
        input_permute
-           ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // B1 layout [G0, N, G2, O]
-           : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // B1 layout [G0, G2, N, O]
+           ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // B1 layout [G0, N, G1KV, O]
+           : std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O};   // B1 layout [G0, G1KV, N, O]
-   std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
+   std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1Q, M, O};
    std::vector<ck::index_t> c_gs_ms_os_strides =
        output_permute
-           ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
-           : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
+           ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // C layout [G0, M, G1Q, O]
+           : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1};  // C layout [G0, G1Q, M, O]
-   std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M, N};
+   std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1Q, M, N};
    std::vector<ck::index_t> d_gs_ms_ns_strides =
        input_permute
-           ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // D layout [G0, M, G1, N]
-           : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // D layout [G0, G1, M, N]
+           ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // D layout [G0, M, G1Q, N]
+           : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1};  // D layout [G0, G1Q, M, N]
-   std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
+   std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
    std::vector<ck::index_t> z_gs_ms_ns_strides =
        input_permute
-           ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
-           : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
+           ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
+           : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1};  // Z layout [G0, G1Q, M, N]
-   std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
+   std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
    std::vector<ck::index_t> lse_gs_ms_strides =
-       std::vector<ck::index_t>{G1 * M, M, 1}; // LSE layout [G0, G1, M]
+       std::vector<ck::index_t>{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
    Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
    Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
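Note: the stride vectors above all follow one rule: each tensor is indexed logically in the [G0, G1Q, ...] (or [G0, G1KV, ...]) order, and input_permute / output_permute only selects which physical layout backs it. A minimal sketch for the A strides, assuming row-major element strides; make_a_strides is a hypothetical helper, not part of the example:

#include <cstdio>
#include <vector>

// Strides (in elements) for logical indexing [G0, G1Q, M, K] over either physical layout.
std::vector<int> make_a_strides(bool input_permute, int G1Q, int M, int K)
{
    return input_permute
               // physical [G0, M, G1Q, K]: one step in G0 skips M*G1Q*K elements,
               // one step in G1Q skips K, one step in M skips G1Q*K, one step in K skips 1
               ? std::vector<int>{M * G1Q * K, K, G1Q * K, 1}
               // physical [G0, G1Q, M, K]: plain row-major strides
               : std::vector<int>{G1Q * M * K, M * K, K, 1};
}

int main()
{
    const auto s = make_a_strides(/*input_permute=*/true, /*G1Q=*/12, /*M=*/256, /*K=*/64);
    std::printf("%d %d %d %d\n", s[0], s[1], s[2], s[3]);
    return 0;
}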
...
@@ -226,7 +226,7 @@ int run(int argc, char* argv[])
        return 0;
    }
-   ck::index_t BatchCount = G0 * G1;
+   ck::index_t BatchCount = G0 * G1Q;
    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
...
@@ -314,37 +314,37 @@ int run(int argc, char* argv[])
        Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
        Tensor<AccDataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
        Tensor<ADataType> a1_g_m_n({BatchCount, M, N});     // scratch object after softmax
-       Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N});
+       Tensor<ADataType> a1_g_m_n_drop({BatchCount, M, N});
        Tensor<LSEDataType> lse_g_m_host_result(
            {BatchCount, M}); // scratch object after max + ln(sum)
-       Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N});
-       Tensor<ZDataType> z_g_m_n({G0 * G1, M, N});
+       Tensor<Acc0BiasDataType> d_g_m_n({BatchCount, M, N});
+       Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
        Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
        // permute
        a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
-           a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+           a_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
        });
        b0_g_k_n.ForEach([&](auto& self, auto idx) {
-           const size_t& g0 = idx[0] / G1;
-           const size_t& g1 = idx[0] % G1;
-           const size_t& g2 = g1 / (G1 / G2);
-           self(idx)        = b0_gs_ns_ks(g0, g2, idx[2], idx[1]);
+           const size_t& g0   = idx[0] / G1Q;
+           const size_t& g1q  = idx[0] % G1Q;
+           const size_t& g1kv = g1q / (G1Q / G1KV);
+           self(idx)          = b0_gs_ns_ks(g0, g1kv, idx[2], idx[1]);
        });
        b1_g_n_o.ForEach([&](auto& self, auto idx) {
-           const size_t& g0 = idx[0] / G1;
-           const size_t& g1 = idx[0] % G1;
-           const size_t& g2 = g1 / (G1 / G2);
-           self(idx)        = b1_gs_os_ns(g0, g2, idx[2], idx[1]);
+           const size_t& g0   = idx[0] / G1Q;
+           const size_t& g1q  = idx[0] % G1Q;
+           const size_t& g1kv = g1q / (G1Q / G1KV);
+           self(idx)          = b1_gs_os_ns(g0, g1kv, idx[2], idx[1]);
        });
        d_gs_ms_ns.ForEach([&](auto& self, auto idx) {
-           d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+           d_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
        });
        z_gs_ms_ns.ForEach([&](auto& self, auto idx) {
-           z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+           z_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
        });
        // gemm 0
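Note: in the b0/b1 permute loops above, each query head g1q is mapped to key/value head g1kv = g1q / (G1Q / G1KV), i.e. consecutive groups of G1Q / G1KV query heads share one KV head (the h_q / h_kv ratio from the comments). A standalone sketch of that mapping, assuming G1Q is a multiple of G1KV; kv_head_for_query_head is a hypothetical helper name:

#include <cstdio>

// Map a query-head index to the KV head it reads from.
int kv_head_for_query_head(int g1q, int G1Q, int G1KV)
{
    const int h_ratio = G1Q / G1KV; // query heads per KV head; assumes G1Q % G1KV == 0
    return g1q / h_ratio;           // same as g1q / (G1Q / G1KV) in the loops above
}

int main()
{
    const int G1Q = 12, G1KV = 4; // example: 3 query heads share each KV head
    for(int g1q = 0; g1q < G1Q; ++g1q)
        std::printf("query head %2d -> kv head %d\n", g1q, kv_head_for_query_head(g1q, G1Q, G1KV));
    return 0;
}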
...
@@ -394,18 +394,18 @@ int run(int argc, char* argv[])
        // permute
        c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
-           const size_t g   = g0 * G1 + g1;
+           const size_t& g1q = idx[1];
+           const size_t g    = g0 * G1Q + g1q;
            self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
        });
        lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
-           const size_t g   = g0 * G1 + g1;
+           const size_t& g1q = idx[1];
+           const size_t g    = g0 * G1Q + g1q;
            self(idx) = lse_g_m_host_result(g, idx[2]);
        });
...
example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward_v2.inc View file @ 10836d41
...
@@ -11,7 +11,7 @@ int run(int argc, char* argv[])
    bool output_permute = true;
    float p_drop = 0.2;
-   int h_ratio  = 1; // G1 / G2
+   int h_ratio  = 1; // G1Q / G1KV
    const unsigned long long seed   = 1;
    const unsigned long long offset = 0;
...
@@ -65,7 +65,7 @@ int run(int argc, char* argv[])
    std::vector<void*> p_z;         // for result verification
    std::vector<void*> p_z_nullptr; // for time test
    std::vector<void*> p_lse;
-   std::vector<std::vector<int>> g0_g1_m_n_k_o;
+   std::vector<std::vector<int>> g0_g1q_m_n_k_o;
    std::vector<Tensor<ADataType>> a_tensors;
    std::vector<Tensor<B0DataType>> b0_tensors;
...
@@ -90,55 +90,57 @@ int run(int argc, char* argv[])
    for(std::size_t i = 0; i < group_count; i++)
    {
        int M = 128 * (rand() % 8) + (rand() % 128);
        int N = 128 * (rand() % 8) + (rand() % 128);
        int K = DIM;
        int O = DIM;
        int G0 = rand() % 3 + 1;
-       int G2 = rand() % 5 + 1;
-       int G1 = G2 * h_ratio;
-       g0_g1_m_n_k_o.push_back({G0, G1, M, N, K, O});
+       int G1KV = rand() % 5 + 1;
+       int G1Q  = G1KV * h_ratio;
+       g0_g1q_m_n_k_o.push_back({G0, G1Q, M, N, K, O});
-       std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
+       std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1Q, M, K};
        std::vector<ck::index_t> a_gs_ms_ks_strides =
            input_permute
-               ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
-               : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
+               ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // A layout [G0, M, G1Q, K]
+               : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1};  // A layout [G0, G1Q, M, K]
-       std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G2, N, K};
+       std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1KV, N, K};
        std::vector<ck::index_t> b0_gs_ns_ks_strides =
            input_permute
-               ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // B0 layout [G0, N, G2, K]
-               : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // B0 layout [G0, G2, N, K]
+               ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // B0 layout [G0, N, G1KV, K]
+               : std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1};   // B0 layout [G0, G1KV, N, K]
-       std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G2, O, N};
+       std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1KV, O, N};
        std::vector<ck::index_t> b1_gs_os_ns_strides =
            input_permute
-               ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // B1 layout [G0, N, G2, O]
-               : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // B1 layout [G0, G2, N, O]
+               ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // B1 layout [G0, N, G1KV, O]
+               : std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O};   // B1 layout [G0, G1KV, N, O]
-       std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
+       std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1Q, M, O};
        std::vector<ck::index_t> c_gs_ms_os_strides =
            output_permute
-               ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
-               : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
+               ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // C layout [G0, M, G1Q, O]
+               : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1};  // C layout [G0, G1Q, M, O]
-       std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M, N};
+       std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1Q, M, N};
        std::vector<ck::index_t> d_gs_ms_ns_strides =
            input_permute
-               ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // D layout [G0, M, G1, N]
-               : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // D layout [G0, G1, M, N]
+               ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // D layout [G0, M, G1Q, N]
+               : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1};  // D layout [G0, G1Q, M, N]
-       std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
+       std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
        std::vector<ck::index_t> z_gs_ms_ns_strides =
            input_permute
-               ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
-               : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
+               ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
+               : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1};  // Z layout [G0, G1Q, M, N]
-       std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
+       std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
        std::vector<ck::index_t> lse_gs_ms_strides =
-           std::vector<ck::index_t>{G1 * M, M, 1}; // LSE layout [G0, G1, M]
+           std::vector<ck::index_t>{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
        problem_descs.push_back({a_gs_ms_ks_lengths,
                                 a_gs_ms_ks_strides,
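Note: the loop above draws one problem per group: M and N are a random multiple of 128 plus a remainder, K and O are fixed to the head dimension DIM, and the query-head count is derived from the KV-head count via h_ratio. A small sketch of just that size generation; ProblemSizes and make_random_problem are hypothetical names and the head dimension is passed in as a parameter:

#include <cstdio>
#include <cstdlib>

struct ProblemSizes
{
    int G0, G1Q, G1KV, M, N, K, O;
};

ProblemSizes make_random_problem(int dim, int h_ratio)
{
    ProblemSizes p{};
    p.M    = 128 * (std::rand() % 8) + (std::rand() % 128); // random sequence length, as above
    p.N    = 128 * (std::rand() % 8) + (std::rand() % 128);
    p.K    = dim; // head dimension (DIM in the example)
    p.O    = dim;
    p.G0   = std::rand() % 3 + 1;
    p.G1KV = std::rand() % 5 + 1;
    p.G1Q  = p.G1KV * h_ratio; // query heads = kv heads * h_ratio
    return p;
}

int main()
{
    const ProblemSizes p = make_random_problem(/*dim=*/64, /*h_ratio=*/1);
    std::printf("G0=%d G1Q=%d G1KV=%d M=%d N=%d K=%d O=%d\n", p.G0, p.G1Q, p.G1KV, p.M, p.N, p.K, p.O);
    return 0;
}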
...
@@ -166,7 +168,7 @@ int run(int argc, char* argv[])
        Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
        Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
-       int Batch = G0 * G1;
+       int Batch = G0 * G1Q;
        flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
        num_byte +=
            (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
@@ -356,12 +358,12 @@ int run(int argc, char* argv[])
...
@@ -356,12 +358,12 @@ int run(int argc, char* argv[])
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
{
const
int
&
G0
=
g0_g1_m_n_k_o
[
i
][
0
];
const
int
&
G0
=
g0_g1
q
_m_n_k_o
[
i
][
0
];
const
int
&
G1
=
g0_g1_m_n_k_o
[
i
][
1
];
const
int
&
G1
Q
=
g0_g1
q
_m_n_k_o
[
i
][
1
];
const
int
&
M
=
g0_g1_m_n_k_o
[
i
][
2
];
const
int
&
M
=
g0_g1
q
_m_n_k_o
[
i
][
2
];
const
int
&
N
=
g0_g1_m_n_k_o
[
i
][
3
];
const
int
&
N
=
g0_g1
q
_m_n_k_o
[
i
][
3
];
const
int
&
K
=
g0_g1_m_n_k_o
[
i
][
4
];
const
int
&
K
=
g0_g1
q
_m_n_k_o
[
i
][
4
];
const
int
&
O
=
g0_g1_m_n_k_o
[
i
][
5
];
const
int
&
O
=
g0_g1
q
_m_n_k_o
[
i
][
5
];
const
auto
&
c_gs_ms_os_lengths
=
problem_descs
[
i
]
.
c_gs_ms_os_lengths
;
const
auto
&
c_gs_ms_os_lengths
=
problem_descs
[
i
]
.
c_gs_ms_os_lengths
;
const
auto
&
c_gs_ms_os_strides
=
problem_descs
[
i
]
.
c_gs_ms_os_strides
;
const
auto
&
c_gs_ms_os_strides
=
problem_descs
[
i
]
.
c_gs_ms_os_strides
;
...
@@ -383,43 +385,43 @@ int run(int argc, char* argv[])
...
@@ -383,43 +385,43 @@ int run(int argc, char* argv[])
z_gs_ms_ns_device_buf
.
FromDevice
(
z_gs_ms_ns_device_result
.
mData
.
data
());
z_gs_ms_ns_device_buf
.
FromDevice
(
z_gs_ms_ns_device_result
.
mData
.
data
());
lse_gs_ms_device_buf
.
FromDevice
(
lse_gs_ms_device_result
.
mData
.
data
());
lse_gs_ms_device_buf
.
FromDevice
(
lse_gs_ms_device_result
.
mData
.
data
());
Tensor
<
ADataType
>
a_g_m_k
({
G0
*
G1
,
M
,
K
});
Tensor
<
ADataType
>
a_g_m_k
({
G0
*
G1
Q
,
M
,
K
});
Tensor
<
B0DataType
>
b0_g_k_n
({
G0
*
G1
,
K
,
N
});
Tensor
<
B0DataType
>
b0_g_k_n
({
G0
*
G1
Q
,
K
,
N
});
Tensor
<
B1DataType
>
b1_g_n_o
({
G0
*
G1
,
N
,
O
});
Tensor
<
B1DataType
>
b1_g_n_o
({
G0
*
G1
Q
,
N
,
O
});
Tensor
<
AccDataType
>
acc0_g_m_n
({
G0
*
G1
,
M
,
N
});
// scratch object after gemm0
Tensor
<
AccDataType
>
acc0_g_m_n
({
G0
*
G1
Q
,
M
,
N
});
// scratch object after gemm0
Tensor
<
Acc0BiasDataType
>
d_g_m_n
({
G0
*
G1
,
M
,
N
});
Tensor
<
Acc0BiasDataType
>
d_g_m_n
({
G0
*
G1
Q
,
M
,
N
});
Tensor
<
ADataType
>
a1_g_m_n
({
G0
*
G1
,
M
,
N
});
// scratch object after softmax
Tensor
<
ADataType
>
a1_g_m_n
({
G0
*
G1
Q
,
M
,
N
});
// scratch object after softmax
Tensor
<
ADataType
>
a1_g_m_n_drop
({
G0
*
G1
,
M
,
N
});
// scratch object after softmax
Tensor
<
ADataType
>
a1_g_m_n_drop
({
G0
*
G1
Q
,
M
,
N
});
// scratch object after softmax
Tensor
<
CDataType
>
c_g_m_o_host_result
({
G0
*
G1
,
M
,
O
});
// scratch object after gemm1
Tensor
<
CDataType
>
c_g_m_o_host_result
({
G0
*
G1
Q
,
M
,
O
});
// scratch object after gemm1
Tensor
<
CDataType
>
c_gs_ms_os_host_result
(
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
);
Tensor
<
CDataType
>
c_gs_ms_os_host_result
(
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
);
Tensor
<
ZDataType
>
z_g_m_n
({
G0
*
G1
,
M
,
N
});
Tensor
<
ZDataType
>
z_g_m_n
({
G0
*
G1
Q
,
M
,
N
});
Tensor
<
LSEDataType
>
lse_g_m_host_result
({
G0
*
G1
,
M
});
// scratch object after gemm1
Tensor
<
LSEDataType
>
lse_g_m_host_result
({
G0
*
G1
Q
,
M
});
// scratch object after gemm1
Tensor
<
LSEDataType
>
lse_gs_ms_host_result
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms_host_result
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
// permute
// permute
a_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
a_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
a_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
a_g_m_k
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
});
b0_g_k_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b0_g_k_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1
;
const
size_t
&
g0
=
idx
[
0
]
/
G1
Q
;
const
size_t
&
g1
=
idx
[
0
]
%
G1
;
const
size_t
&
g1
q
=
idx
[
0
]
%
G1
Q
;
const
size_t
&
g
2
=
g1
/
h_ratio
;
const
size_t
&
g
1kv
=
g1
q
/
h_ratio
;
self
(
idx
)
=
b0_gs_ns_ks
(
g0
,
g
2
,
idx
[
2
],
idx
[
1
]);
self
(
idx
)
=
b0_gs_ns_ks
(
g0
,
g
1kv
,
idx
[
2
],
idx
[
1
]);
});
});
b1_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b1_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1
;
const
size_t
&
g0
=
idx
[
0
]
/
G1
Q
;
const
size_t
&
g1
=
idx
[
0
]
%
G1
;
const
size_t
&
g1
q
=
idx
[
0
]
%
G1
Q
;
const
size_t
&
g
2
=
g1
/
h_ratio
;
const
size_t
&
g
1kv
=
g1
q
/
h_ratio
;
self
(
idx
)
=
b1_gs_os_ns
(
g0
,
g
2
,
idx
[
2
],
idx
[
1
]);
self
(
idx
)
=
b1_gs_os_ns
(
g0
,
g
1kv
,
idx
[
2
],
idx
[
1
]);
});
});
d_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
d_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
d_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
d_g_m_n
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
});
z_gs_ms_ns_device_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
z_gs_ms_ns_device_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
z_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
z_g_m_n
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
});
// gemm 0
// gemm 0
...
@@ -473,18 +475,18 @@ int run(int argc, char* argv[])
        // permute
        c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
-           const size_t g   = g0 * G1 + g1;
+           const size_t& g1q = idx[1];
+           const size_t g    = g0 * G1Q + g1q;
            self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
        });
        lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
-           const size_t& g1 = idx[1];
-           const size_t g   = g0 * G1 + g1;
+           const size_t& g1q = idx[1];
+           const size_t g    = g0 * G1Q + g1q;
            self(idx) = lse_g_m_host_result(g, idx[2]);
        });
...