gaoqiong / composable_kernel · Commit 29398e70
Authored Sep 27, 2023 by danyao12
Parent 617bdf3f

update 52 examples w/ mqa/gqa

Showing 4 changed files with 190 additions and 93 deletions
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp    +63 −24
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp    +67 −33
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward_v2.inc +28 −18
example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward_v2.inc +32 −18
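Editor's note: this commit teaches the example set about multi-query attention (MQA) and grouped-query attention (GQA). Throughout the diff, Q keeps G1 heads (h_q) while K and V carry only G2 heads (h_kv), and each group of G1 / G2 query heads shares one KV head. A minimal standalone sketch of that mapping (the names num_q_heads / num_kv_heads are illustrative, not from the commit):

#include <cstdio>

// Minimal sketch of the GQA head mapping used throughout this commit:
// query head g1 reads from kv head g2 = g1 / (G1 / G2).
int main()
{
    const int num_q_heads  = 6; // G1 (h_q)
    const int num_kv_heads = 3; // G2 (h_kv); assumed to divide num_q_heads
    const int h_ratio      = num_q_heads / num_kv_heads;

    for(int g1 = 0; g1 < num_q_heads; ++g1)
    {
        const int g2 = g1 / h_ratio; // kv head shared by this query head
        std::printf("q head %d -> kv head %d\n", g1, g2);
    }
    return 0;
}

With h_ratio == 1 this degenerates to ordinary multi-head attention; with num_kv_heads == 1 it is MQA.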
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp
@@ -280,7 +280,8 @@ int run(int argc, char* argv[])
     ck::index_t K  = DIM;
     ck::index_t O  = DIM;
     ck::index_t G0 = 4;
-    ck::index_t G1 = 6;
+    ck::index_t G1 = 6; // h_q
+    ck::index_t G2 = 6; // h_kv
     bool input_permute  = false;
     bool output_permute = false;
@@ -299,7 +300,7 @@ int run(int argc, char* argv[])
         init_method     = std::stoi(argv[2]);
         time_kernel     = std::stoi(argv[3]);
     }
-    else if(argc == 13)
+    else if(argc == 14)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
@@ -311,20 +312,21 @@ int run(int argc, char* argv[])
         O  = std::stoi(argv[7]);
         G0 = std::stoi(argv[8]);
         G1 = std::stoi(argv[9]);
+        G2 = std::stoi(argv[10]);
-        p_drop         = std::stof(argv[10]);
-        input_permute  = std::stoi(argv[11]);
-        output_permute = std::stoi(argv[12]);
+        p_drop         = std::stof(argv[11]);
+        input_permute  = std::stoi(argv[12]);
+        output_permute = std::stoi(argv[13]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 11: M, N, K, O, G0, G1\n");
-        printf("arg10: scale (alpha)\n");
-        printf("arg11 to 12: input / output permute\n");
+        printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
+        printf("arg11: p_drop\n");
+        printf("arg12 to 13: input / output permute\n");
         exit(0);
     }
@@ -342,6 +344,7 @@ int run(int argc, char* argv[])
     std::cout << "O: " << O << std::endl;
     std::cout << "G0: " << G0 << std::endl;
     std::cout << "G1: " << G1 << std::endl;
+    std::cout << "G2: " << G2 << std::endl;
     std::cout << "alpha: " << alpha << std::endl;
     std::cout << "input_permute: " << input_permute << std::endl;
     std::cout << "output_permute: " << output_permute << std::endl;
@@ -357,17 +360,17 @@ int run(int argc, char* argv[])
             ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
             : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
-    std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K};
+    std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
     std::vector<ck::index_t> k_gs_ns_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K]
-            : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K]
+            ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
+            : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
-    std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N};
+    std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
     std::vector<ck::index_t> v_gs_os_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O]
-            : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O]
+            ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
+            : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
     std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
     std::vector<ck::index_t> y_gs_ms_os_strides =
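Editor's note: input_permute selects between a [G0, M, G1, K] memory layout (heads interleaved per row) and the packed [G0, G1, M, K] layout; the stride vectors above are simply the row-major strides of each physical ordering, reported in the tensor's logical [G, G, M/N, K/O] index order. As a sanity check, here is a hedged sketch that derives those strides from an axis permutation (helper names are mine, not from the example):

#include <array>
#include <cstdio>

// Row-major strides of a 4-D shape, reported in logical axis order.
// Physical order {G0, M, G1, K} with logical order (G0, G1, M, K)
// reproduces {M * G1 * K, K, G1 * K, 1} from the hunk above.
std::array<long, 4> logical_strides(const std::array<long, 4>& phys_dims,
                                    const std::array<int, 4>& logical_to_phys)
{
    std::array<long, 4> phys_strides{};
    long s = 1;
    for(int a = 3; a >= 0; --a)
    {
        phys_strides[a] = s; // innermost physical axis is contiguous
        s *= phys_dims[a];
    }
    std::array<long, 4> out{};
    for(int a = 0; a < 4; ++a)
        out[a] = phys_strides[logical_to_phys[a]];
    return out;
}

int main()
{
    const long G0 = 4, G1 = 6, M = 128, K = 64;
    // Q stored as [G0, M, G1, K]; logical order is (G0, G1, M, K).
    auto st = logical_strides({G0, M, G1, K}, {0, 2, 1, 3});
    std::printf("{%ld, %ld, %ld, %ld}\n", st[0], st[1], st[2], st[3]); // {M*G1*K, K, G1*K, 1}
    return 0;
}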
@@ -386,6 +389,18 @@ int run(int argc, char* argv[])
         input_permute
             ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
             : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
+    std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
+    std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
+        input_permute
+            ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
+            : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
+    std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
+    std::vector<ck::index_t> vgrad_gs_os_ns_strides =
+        input_permute
+            ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
+            : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
     // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
     // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
     //    = exp(Si) / exp(log(sum(exp() + ...)))
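Editor's note: the comment block restates the standard log-sum-exp trick. With LSE = log(sum_j exp(S_j)), each softmax probability reduces to P_i = exp(S_i - LSE), so the backward pass can recompute P from the scores and the saved per-row LSE without redoing the softmax sum. A small self-contained sketch, using the usual max-shift for numerical stability (my code, not the example's):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Recompute softmax from scores plus saved LSE: P_i = exp(S_i - LSE).
int main()
{
    std::vector<float> s{1.0f, 2.0f, 3.0f};

    // Forward pass saves one LSE per row (max-shifted for stability).
    float m = s[0];
    for(float v : s) m = std::max(m, v);
    float sum = 0.f;
    for(float v : s) sum += std::exp(v - m);
    const float lse = m + std::log(sum); // LSE = log(sum_j exp(S_j))

    // Backward pass: each P_i needs no second sum reduction.
    for(float v : s)
        std::printf("%f\n", std::exp(v - lse));
    return 0;
}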
@@ -403,6 +418,8 @@ int run(int argc, char* argv[])
     Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
     Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
     Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
+    Tensor<OutputDataType> kgrad_gs_ns_ks(kgrad_gs_ns_ks_lengths, kgrad_gs_ns_ks_strides);
+    Tensor<OutputDataType> vgrad_gs_os_ns(vgrad_gs_os_ns_lengths, vgrad_gs_os_ns_strides);
     std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
     std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
@@ -411,6 +428,8 @@ int run(int argc, char* argv[])
     std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
     std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
     std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
+    std::cout << "kgrad_gs_ns_ks: " << kgrad_gs_ns_ks.mDesc << std::endl;
+    std::cout << "vgrad_gs_os_ns: " << vgrad_gs_os_ns.mDesc << std::endl;
     z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
     switch(init_method)
@@ -491,8 +510,8 @@ int run(int argc, char* argv[])
     DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
     DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
     DeviceMem qgrad_device_buf(sizeof(OutputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
-    DeviceMem kgrad_device_buf(sizeof(OutputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
-    DeviceMem vgrad_device_buf(sizeof(OutputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
+    DeviceMem kgrad_device_buf(sizeof(OutputDataType) * kgrad_gs_ns_ks.mDesc.GetElementSpaceSize());
+    DeviceMem vgrad_device_buf(sizeof(OutputDataType) * vgrad_gs_os_ns.mDesc.GetElementSpaceSize());
     DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
     DeviceMem d0grad_device_buf(sizeof(Acc0BiasDataType) * d0_gs_ms_ns.mDesc.GetElementSpaceSize());
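Editor's note: the dK/dV buffers switch from the K/V descriptors to their own, presumably because under GQA k_gs_ns_ks and v_gs_os_ns now carry only G2 heads while the per-query-head gradients keep G1 heads; sizing dK by K would under-allocate by a factor of G1 / G2. A quick hedged arithmetic sketch of the two element counts (illustrative numbers):

#include <cstdio>

// Why kgrad needs its own descriptor: element counts of a packed
// [G0, G, N, K] tensor for G = G2 (K itself) vs G = G1 (its gradient).
int main()
{
    const long G0 = 4, G1 = 6, G2 = 3, N = 256, K = 64;
    const long k_elems     = G0 * G2 * N * K; // what k_gs_ns_ks holds
    const long kgrad_elems = G0 * G1 * N * K; // what the dK buffer must hold
    std::printf("k: %ld elements, kgrad: %ld elements (x%ld)\n",
                k_elems, kgrad_elems, kgrad_elems / k_elems);
    return 0;
}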
@@ -533,6 +552,10 @@ int run(int argc, char* argv[])
         y_gs_ms_os_lengths,
         y_gs_ms_os_strides,
         lse_gs_ms_lengths,
+        kgrad_gs_ns_ks_lengths,
+        kgrad_gs_ns_ks_strides,
+        vgrad_gs_os_ns_lengths,
+        vgrad_gs_os_ns_strides,
         d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
         d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
         {},                  // acc1_bias_gs_ms_os_lengths,
@@ -580,6 +603,10 @@ int run(int argc, char* argv[])
         y_gs_ms_os_lengths,
         y_gs_ms_os_strides,
         lse_gs_ms_lengths,
+        kgrad_gs_ns_ks_lengths,
+        kgrad_gs_ns_ks_strides,
+        vgrad_gs_os_ns_lengths,
+        vgrad_gs_os_ns_strides,
         d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
         d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
         {},                  // acc1_bias_gs_ms_os_lengths,
@@ -640,11 +667,19 @@ int run(int argc, char* argv[])
     q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
         q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
     });
-    k_gs_ns_ks.ForEach([&](auto& self, auto idx) {
-        k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+    k_g_n_k.ForEach([&](auto& self, auto idx) {
+        const size_t& g0 = idx[0] / G1;
+        const size_t& g1 = idx[0] % G1;
+        const size_t& g2 = g1 / (G1 / G2);
+        self(idx) = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
     });
-    v_gs_os_ns.ForEach([&](auto& self, auto idx) {
-        v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+    v_g_n_o.ForEach([&](auto& self, auto idx) {
+        const size_t& g0 = idx[0] / G1;
+        const size_t& g1 = idx[0] % G1;
+        const size_t& g2 = g1 / (G1 / G2);
+        self(idx) = v_gs_os_ns(g0, g2, idx[2], idx[1]);
     });
     d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
         d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
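Editor's note: the rewritten loops iterate over the flat destination tensors (k_g_n_k, v_g_n_o) instead of the sources: flat index idx[0] encodes (g0, g1) as g0 * G1 + g1, and g2 = g1 / (G1 / G2) selects the KV head that query head g1 shares, so each of the G2 source heads is replicated across G1 / G2 destination slots. A self-contained sketch of the same broadcast on plain arrays (names are illustrative):

#include <cstdio>
#include <vector>

// Broadcast G2 kv heads across G1 query-head slots, mirroring the
// k_g_n_k.ForEach loop above: dst head (g0, g1) copies src head (g0, g2)
// with g2 = g1 / (G1 / G2).
int main()
{
    const int G0 = 2, G1 = 6, G2 = 3;
    std::vector<int> src(G0 * G2); // one value per kv head
    for(int i = 0; i < G0 * G2; ++i) src[i] = i;

    std::vector<int> dst(G0 * G1);
    for(int g = 0; g < G0 * G1; ++g)
    {
        const int g0 = g / G1;
        const int g1 = g % G1;
        const int g2 = g1 / (G1 / G2);
        dst[g] = src[g0 * G2 + g2];
    }
    for(int g = 0; g < G0 * G1; ++g)
        std::printf("dst head %d <- src value %d\n", g, dst[g]);
    return 0;
}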
@@ -787,14 +822,18 @@ int run(int argc, char* argv[])
 #endif
     Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-    Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-    Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
+    Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(kgrad_gs_ns_ks_lengths,
+                                                      kgrad_gs_ns_ks_strides);
+    Tensor<OutputDataType> vgrad_gs_os_ns_host_result(vgrad_gs_os_ns_lengths,
+                                                      vgrad_gs_os_ns_strides);
     Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_host_result(d0_gs_ms_ns_lengths,
                                                          d0_gs_ms_ns_strides);
     Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-    Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-    Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
+    Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(kgrad_gs_ns_ks_lengths,
+                                                        kgrad_gs_ns_ks_strides);
+    Tensor<OutputDataType> vgrad_gs_os_ns_device_result(vgrad_gs_os_ns_lengths,
+                                                        vgrad_gs_os_ns_strides);
     Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_device_result(d0_gs_ms_ns_lengths,
                                                            d0_gs_ms_ns_strides);
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp
@@ -275,6 +275,7 @@ int run(int argc, char* argv[])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
     float alpha  = 1.f / std::sqrt(DIM);
     float p_drop = 0.0;
+    int h_ratio  = 1; // G1 / G2
     bool input_permute  = true;
     bool output_permute = true;
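Editor's note: in the grouped examples the head ratio is a runtime knob rather than a second head count: h_ratio = G1 / G2, with h_ratio == 1 giving ordinary MHA and G2 == 1 (h_ratio == G1) giving MQA. The example itself derives G1 = G2 * h_ratio, so the division is exact by construction; when driving similar code with independently chosen head counts, a caller-side check like this hedged sketch (my code, not in the commit) is worth keeping in mind:

#include <cstdio>
#include <cstdlib>

// The GQA mapping g2 = g1 / h_ratio only partitions the query heads
// evenly when h_ratio divides G1; validate before building descriptors.
// (Hypothetical helper, not part of the example.)
void validate_heads(int G1, int h_ratio)
{
    if(h_ratio <= 0 || G1 % h_ratio != 0)
    {
        std::fprintf(stderr, "h_ratio %d must evenly divide G1 %d\n", h_ratio, G1);
        std::exit(1);
    }
}

int main()
{
    validate_heads(/*G1=*/8, /*h_ratio=*/2); // ok: G2 = 4
    validate_heads(/*G1=*/8, /*h_ratio=*/3); // exits: 8 % 3 != 0
    return 0;
}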
@@ -292,25 +293,26 @@ int run(int argc, char* argv[])
         init_method     = std::stoi(argv[2]);
         time_kernel     = std::stoi(argv[3]);
     }
-    else if(argc == 7)
+    else if(argc == 8)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
         time_kernel     = std::stoi(argv[3]);
         p_drop          = std::stof(argv[4]);
-        input_permute   = std::stoi(argv[5]);
-        output_permute  = std::stoi(argv[6]);
+        h_ratio         = std::stof(argv[5]);
+        input_permute   = std::stoi(argv[6]);
+        output_permute  = std::stoi(argv[7]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 11: M, N, K, O, G0, G1\n");
-        printf("arg10: scale (alpha)\n");
-        printf("arg11 to 12: input / output permute\n");
+        printf("arg4: p_drop\n");
+        printf("arg5: h_ratio\n");
+        printf("arg6 to 7: input / output permute\n");
         exit(0);
     }
@@ -382,24 +384,25 @@ int run(int argc, char* argv[])
         int K = DIM;
         int O = DIM;
         int G0 = rand() % 4 + 1;
-        int G1 = rand() % 4 + 1;
+        int G2 = rand() % 4 + 1;
+        int G1 = G2 * h_ratio;
         std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
         std::vector<ck::index_t> q_gs_ms_ks_strides =
             input_permute
                 ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
                 : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
-        std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K};
+        std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
         std::vector<ck::index_t> k_gs_ns_ks_strides =
             input_permute
-                ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K]
-                : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K]
+                ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
+                : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
-        std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N};
+        std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
         std::vector<ck::index_t> v_gs_os_ns_strides =
             input_permute
-                ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O]
-                : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O]
+                ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
+                : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
         std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
         std::vector<ck::index_t> y_gs_ms_os_strides =
@@ -418,6 +421,17 @@ int run(int argc, char* argv[])
         input_permute
             ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
             : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
+        std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
+        std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
+            input_permute
+                ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
+                : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
+        std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
+        std::vector<ck::index_t> vgrad_gs_os_ns_strides =
+            input_permute
+                ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
+                : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
         // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
         // pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
         //         = exp(Si) / exp(log(sum(exp() + ...)))
@@ -439,6 +453,10 @@ int run(int argc, char* argv[])
             y_gs_ms_os_strides,
             lse_gs_ms_lengths,
             lse_gs_ms_strides,
+            kgrad_gs_ns_ks_lengths,
+            kgrad_gs_ns_ks_strides,
+            vgrad_gs_os_ns_lengths,
+            vgrad_gs_os_ns_strides,
             d0_gs_ms_ns_lengths,
             d0_gs_ms_ns_strides,
             {}, // acc1_bias_gs_ms_os_lengths,
@@ -464,6 +482,8 @@ int run(int argc, char* argv[])
         Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
         Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
         Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
+        Tensor<OutputDataType> kgrad_gs_ns_ks(kgrad_gs_ns_ks_lengths, kgrad_gs_ns_ks_strides);
+        Tensor<OutputDataType> vgrad_gs_os_ns(vgrad_gs_os_ns_lengths, vgrad_gs_os_ns_strides);
         if(i < 4)
         {
             std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
@@ -473,6 +493,8 @@ int run(int argc, char* argv[])
             std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
             std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
             std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
+            std::cout << "kgrad_gs_ns_ks: " << kgrad_gs_ns_ks.mDesc << std::endl;
+            std::cout << "vgrad_gs_os_ns: " << vgrad_gs_os_ns.mDesc << std::endl;
         }
         z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
         switch(init_method)
@@ -558,14 +580,22 @@ int run(int argc, char* argv[])
         q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
             q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
         });
-        k_gs_ns_ks.ForEach([&](auto& self, auto idx) {
-            k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        k_g_n_k.ForEach([&](auto& self, auto idx) {
+            const size_t& g0 = idx[0] / G1;
+            const size_t& g1 = idx[0] % G1;
+            const size_t& g2 = g1 / h_ratio;
+            self(idx) = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
         });
         d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
             d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
         });
-        v_gs_os_ns.ForEach([&](auto& self, auto idx) {
-            v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+        v_g_n_o.ForEach([&](auto& self, auto idx) {
+            const size_t& g0 = idx[0] / G1;
+            const size_t& g1 = idx[0] % G1;
+            const size_t& g2 = g1 / h_ratio;
+            self(idx) = v_gs_os_ns(g0, g2, idx[2], idx[1]);
         });
         q_g_m_ks.push_back(q_g_m_k);
@@ -586,6 +616,8 @@ int run(int argc, char* argv[])
         z_tensors.push_back(z_gs_ms_ns);
         lse_tensors.push_back(lse_gs_ms);
         ygrad_tensors.push_back(ygrad_gs_ms_os);
+        kgrad_tensors.push_back(kgrad_gs_ns_ks);
+        vgrad_tensors.push_back(vgrad_gs_os_ns);
         q_tensors_device.emplace_back(
             std::make_unique<DeviceMem>(sizeof(InputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
         k_tensors_device.emplace_back(
@@ -602,12 +634,12 @@ int run(int argc, char* argv[])
             std::make_unique<DeviceMem>(sizeof(LSEDataType) * lse_gs_ms.GetElementSpaceSize()));
         qgrad_tensors_device.emplace_back(
             std::make_unique<DeviceMem>(sizeof(OutputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
         kgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(
-            sizeof(OutputDataType) * k_gs_ns_ks.GetElementSpaceSize()));
+            sizeof(OutputDataType) * kgrad_gs_ns_ks.GetElementSpaceSize()));
         d0grad_tensors_device.emplace_back(std::make_unique<DeviceMem>(
             sizeof(Acc0BiasDataType) * d0_gs_ms_ns.GetElementSpaceSize()));
         vgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(
-            sizeof(OutputDataType) * v_gs_os_ns.GetElementSpaceSize()));
+            sizeof(OutputDataType) * vgrad_gs_os_ns.GetElementSpaceSize()));
         ygrad_tensors_device.emplace_back(
             std::make_unique<DeviceMem>(sizeof(InputDataType) * y_gs_ms_os.GetElementSpaceSize()));
         q_tensors_device.back()->ToDevice(q_gs_ms_ks.data());
@@ -652,6 +684,7 @@ int run(int argc, char* argv[])
         QKVElementOp{},
         YElementOp{},
         p_drop,
+        h_ratio,
         std::tuple<unsigned long long, unsigned long long>(seed, offset));
     DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
@@ -700,6 +733,7 @@ int run(int argc, char* argv[])
         QKVElementOp{},
         YElementOp{},
         p_drop,
+        h_ratio,
         std::tuple<unsigned long long, unsigned long long>(seed, offset));
     DeviceMem problem_desc_workspace_verify(gemm.GetWorkSpaceSize(&argument));
     gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer());
@@ -713,7 +747,7 @@ int run(int argc, char* argv[])
     for(std::size_t i = 0; i < group_count; i++)
     {
-        int G1 = v_tensors[i].GetLengths()[1];
+        int G1 = q_tensors[i].GetLengths()[1];
         // copy z matrix data from device
         z_tensors_device[i]->FromDevice(z_tensors[i].mData.data());
         z_tensors[i].ForEach([&](auto& self, auto idx) {
@@ -752,8 +786,8 @@ int run(int argc, char* argv[])
     for(std::size_t i = 0; i < group_count; i++)
     {
-        int G0 = v_tensors[i].GetLengths()[0];
-        int G1 = v_tensors[i].GetLengths()[1];
+        int G0 = q_tensors[i].GetLengths()[0];
+        int G1 = q_tensors[i].GetLengths()[1];
         int O  = v_tensors[i].GetLengths()[2];
         int N  = v_tensors[i].GetLengths()[3];
         int M  = q_tensors[i].GetLengths()[2];
@@ -814,21 +848,21 @@ int run(int argc, char* argv[])
         Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(),
                                                           q_tensors[i].GetStrides());
-        Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(),
-                                                          k_tensors[i].GetStrides());
+        Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(kgrad_tensors[i].GetLengths(),
+                                                          kgrad_tensors[i].GetStrides());
         Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_host_result(d0_tensors[i].GetLengths(),
                                                              d0_tensors[i].GetStrides());
-        Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(),
-                                                          v_tensors[i].GetStrides());
+        Tensor<OutputDataType> vgrad_gs_os_ns_host_result(vgrad_tensors[i].GetLengths(),
+                                                          vgrad_tensors[i].GetStrides());
         Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(),
                                                             q_tensors[i].GetStrides());
-        Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(),
-                                                            k_tensors[i].GetStrides());
+        Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(kgrad_tensors[i].GetLengths(),
+                                                            kgrad_tensors[i].GetStrides());
         Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_device_result(d0_tensors[i].GetLengths(),
                                                                d0_tensors[i].GetStrides());
-        Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(),
-                                                            v_tensors[i].GetStrides());
+        Tensor<OutputDataType> vgrad_gs_os_ns_device_result(vgrad_tensors[i].GetLengths(),
+                                                            vgrad_tensors[i].GetStrides());
         qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data());
         kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data());
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward_v2.inc
@@ -18,7 +18,8 @@ int run(int argc, char* argv[])
     // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
     // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
     ck::index_t G0 = 7;
-    ck::index_t G1 = 13;
+    ck::index_t G1 = 12; // h_q
+    ck::index_t G2 = 12; // h_kv
     bool input_permute  = false;
     bool output_permute = true;
@@ -37,7 +38,7 @@ int run(int argc, char* argv[])
         init_method     = std::stoi(argv[2]);
         time_kernel     = std::stoi(argv[3]);
     }
-    else if(argc == 13)
+    else if(argc == 14)
    {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
@@ -49,20 +50,21 @@ int run(int argc, char* argv[])
         O  = std::stoi(argv[7]);
         G0 = std::stoi(argv[8]);
         G1 = std::stoi(argv[9]);
+        G2 = std::stoi(argv[10]);
-        p_drop         = std::stof(argv[10]);
-        input_permute  = std::stoi(argv[11]);
-        output_permute = std::stoi(argv[12]);
+        p_drop         = std::stof(argv[11]);
+        input_permute  = std::stoi(argv[12]);
+        output_permute = std::stoi(argv[13]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 11: M, N, K, O, G0, G1\n");
-        printf("arg10: scale (alpha)\n");
-        printf("arg11 to 12: input / output permute\n");
+        printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
+        printf("arg11: p_drop\n");
+        printf("arg12 to 13: input / output permute\n");
         exit(0);
     }
@@ -77,17 +79,17 @@ int run(int argc, char* argv[])
             ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
             : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
-    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
+    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G2, N, K};
     std::vector<ck::index_t> b0_gs_ns_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
-            : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
+            ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // B0 layout [G0, N, G2, K]
+            : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // B0 layout [G0, G2, N, K]
-    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
+    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G2, O, N};
     std::vector<ck::index_t> b1_gs_os_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
-            : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
+            ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // B1 layout [G0, N, G2, O]
+            : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // B1 layout [G0, G2, N, O]
     std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
     std::vector<ck::index_t> c_gs_ms_os_strides =
@@ -323,11 +325,19 @@ int run(int argc, char* argv[])
     a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
         a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
     });
-    b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
-        b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+    b0_g_k_n.ForEach([&](auto& self, auto idx) {
+        const size_t& g0 = idx[0] / G1;
+        const size_t& g1 = idx[0] % G1;
+        const size_t& g2 = g1 / (G1 / G2);
+        self(idx) = b0_gs_ns_ks(g0, g2, idx[2], idx[1]);
     });
-    b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
-        b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+    b1_g_n_o.ForEach([&](auto& self, auto idx) {
+        const size_t& g0 = idx[0] / G1;
+        const size_t& g1 = idx[0] % G1;
+        const size_t& g2 = g1 / (G1 / G2);
+        self(idx) = b1_gs_os_ns(g0, g2, idx[2], idx[1]);
     });
     d_gs_ms_ns.ForEach([&](auto& self, auto idx) {
         d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
example/52_flash_atten_bias/run_grouped_multihead_attention_bias_forward_v2.inc
@@ -11,6 +11,7 @@ int run(int argc, char* argv[])
     bool output_permute = true;
     float p_drop        = 0.2;
+    int h_ratio         = 1; // G1 / G2
     const unsigned long long seed   = 1;
     const unsigned long long offset = 0;
@@ -24,22 +25,25 @@ int run(int argc, char* argv[])
         init_method     = std::stoi(argv[2]);
         time_kernel     = std::stoi(argv[3]);
     }
-    else if(argc == 7)
+    else if(argc == 8)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
         time_kernel     = std::stoi(argv[3]);
         p_drop          = std::stoi(argv[4]);
-        input_permute   = std::stoi(argv[5]);
-        output_permute  = std::stoi(argv[6]);
+        h_ratio         = std::stof(argv[5]);
+        input_permute   = std::stoi(argv[6]);
+        output_permute  = std::stoi(argv[7]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 5: input / output permute\n");
+        printf("arg4: p_drop\n");
+        printf("arg5: h_ratio\n");
+        printf("arg6 to 7: input / output permute\n");
         exit(0);
     }
@@ -91,7 +95,8 @@ int run(int argc, char* argv[])
         int K = DIM;
         int O = DIM;
         int G0 = rand() % 3 + 1;
-        int G1 = rand() % 5 + 1;
+        int G2 = rand() % 5 + 1;
+        int G1 = G2 * h_ratio;
         g0_g1_m_n_k_o.push_back({G0, G1, M, N, K, O});
@@ -101,17 +106,17 @@ int run(int argc, char* argv[])
             ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
             : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
-    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
+    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G2, N, K};
     std::vector<ck::index_t> b0_gs_ns_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
-            : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
+            ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // B0 layout [G0, N, G2, K]
+            : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // B0 layout [G0, G2, N, K]
-    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
+    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G2, O, N};
     std::vector<ck::index_t> b1_gs_os_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
-            : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
+            ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // B1 layout [G0, N, G2, O]
+            : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // B1 layout [G0, G2, N, O]
     std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
     std::vector<ck::index_t> c_gs_ms_os_strides =
@@ -276,6 +281,7 @@ int run(int argc, char* argv[])
             b1_element_op,
             c_element_op,
             p_drop, // dropout ratio
+            h_ratio,
             {seed, offset}); // dropout random seed and offset, offset should be
                              // at least the number of elements on a thread
@@ -331,6 +337,7 @@ int run(int argc, char* argv[])
             b1_element_op,
             c_element_op,
             p_drop, // dropout ratio
+            h_ratio,
             {seed, offset}); // dropout random seed and offset, offset should be
                              // at least the number of elements on a thread
     // specify workspace for problem_desc
@@ -395,13 +402,20 @@ int run(int argc, char* argv[])
...
@@ -395,13 +402,20 @@ int run(int argc, char* argv[])
a_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
a_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
a_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
a_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
});
b0_gs_ns_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b0_g_k_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b0_g_k_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
const
size_t
&
g0
=
idx
[
0
]
/
G1
;
});
const
size_t
&
g1
=
idx
[
0
]
%
G1
;
b1_gs_os_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g2
=
g1
/
h_ratio
;
b1_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
self
(
idx
)
=
b0_gs_ns_ks
(
g0
,
g2
,
idx
[
2
],
idx
[
1
]);
});
});
b1_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1
;
const
size_t
&
g1
=
idx
[
0
]
%
G1
;
const
size_t
&
g2
=
g1
/
h_ratio
;
self
(
idx
)
=
b1_gs_os_ns
(
g0
,
g2
,
idx
[
2
],
idx
[
1
]);
});
d_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
d_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
d_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
d_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
});
...
...