Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
104aeabc
Commit
104aeabc
authored
Sep 27, 2023
by
danyao12
Browse files
split bwd mqa/gqa
parent
198558c5
Changes
9
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
482 additions
and
114 deletions
+482
-114
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp
..._softmax_gemm/batched_multihead_attention_backward_v2.cpp
+4
-4
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp
..._softmax_gemm/batched_multihead_attention_backward_v3.cpp
+65
-24
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp
..._softmax_gemm/grouped_multihead_attention_backward_v3.cpp
+68
-34
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+95
-12
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+92
-8
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+74
-12
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+72
-12
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
+6
-4
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
+6
-4
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp
View file @
104aeabc
...
@@ -542,8 +542,8 @@ int run(int argc, char* argv[])
...
@@ -542,8 +542,8 @@ int run(int argc, char* argv[])
static_cast
<
OutputDataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
nullptr
,
//
p_acc0_bias;
nullptr
,
// p_acc0_bias;
nullptr
,
//
p_acc1_bias;
nullptr
,
// p_acc1_bias;
nullptr
,
nullptr
,
nullptr
,
nullptr
,
q_gs_ms_ks_lengths
,
q_gs_ms_ks_lengths
,
...
@@ -593,8 +593,8 @@ int run(int argc, char* argv[])
...
@@ -593,8 +593,8 @@ int run(int argc, char* argv[])
static_cast
<
OutputDataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
qgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
kgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutputDataType
*>
(
vgrad_device_buf
.
GetDeviceBuffer
()),
nullptr
,
//
p_acc0_bias;
nullptr
,
// p_acc0_bias;
nullptr
,
//
p_acc1_bias;
nullptr
,
// p_acc1_bias;
nullptr
,
nullptr
,
nullptr
,
nullptr
,
q_gs_ms_ks_lengths
,
q_gs_ms_ks_lengths
,
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp
View file @
104aeabc
...
@@ -277,7 +277,8 @@ int run(int argc, char* argv[])
...
@@ -277,7 +277,8 @@ int run(int argc, char* argv[])
ck
::
index_t
K
=
DIM
;
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
G0
=
4
;
ck
::
index_t
G0
=
4
;
ck
::
index_t
G1
=
6
;
ck
::
index_t
G1
=
6
;
// h_q
ck
::
index_t
G2
=
6
;
// h_kv
bool
input_permute
=
false
;
bool
input_permute
=
false
;
bool
output_permute
=
false
;
bool
output_permute
=
false
;
...
@@ -296,7 +297,7 @@ int run(int argc, char* argv[])
...
@@ -296,7 +297,7 @@ int run(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
}
else
if
(
argc
==
1
3
)
else
if
(
argc
==
1
4
)
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
...
@@ -308,20 +309,21 @@ int run(int argc, char* argv[])
...
@@ -308,20 +309,21 @@ int run(int argc, char* argv[])
O
=
std
::
stoi
(
argv
[
7
]);
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1
=
std
::
stoi
(
argv
[
9
]);
G1
=
std
::
stoi
(
argv
[
9
]);
G2
=
std
::
stoi
(
argv
[
10
]);
p_drop
=
std
::
stof
(
argv
[
1
0
]);
p_drop
=
std
::
stof
(
argv
[
1
1
]);
input_permute
=
std
::
stoi
(
argv
[
1
1
]);
input_permute
=
std
::
stoi
(
argv
[
1
2
]);
output_permute
=
std
::
stoi
(
argv
[
1
2
]);
output_permute
=
std
::
stoi
(
argv
[
1
3
]);
}
}
else
else
{
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 1
1
: M, N, K, O, G0, G1
\n
"
);
printf
(
"arg4 to 1
0
: M, N, K, O, G0, G1
, G2
\n
"
);
printf
(
"arg1
0
:
scale (alpha)
\n
"
);
printf
(
"arg1
1
:
p_drop
\n
"
);
printf
(
"arg1
1
to 1
2
: input / output permute
\n
"
);
printf
(
"arg1
2
to 1
3
: input / output permute
\n
"
);
exit
(
0
);
exit
(
0
);
}
}
...
@@ -339,6 +341,7 @@ int run(int argc, char* argv[])
...
@@ -339,6 +341,7 @@ int run(int argc, char* argv[])
std
::
cout
<<
"O: "
<<
O
<<
std
::
endl
;
std
::
cout
<<
"O: "
<<
O
<<
std
::
endl
;
std
::
cout
<<
"G0: "
<<
G0
<<
std
::
endl
;
std
::
cout
<<
"G0: "
<<
G0
<<
std
::
endl
;
std
::
cout
<<
"G1: "
<<
G1
<<
std
::
endl
;
std
::
cout
<<
"G1: "
<<
G1
<<
std
::
endl
;
std
::
cout
<<
"G2: "
<<
G2
<<
std
::
endl
;
std
::
cout
<<
"alpha: "
<<
alpha
<<
std
::
endl
;
std
::
cout
<<
"alpha: "
<<
alpha
<<
std
::
endl
;
std
::
cout
<<
"input_permute: "
<<
input_permute
<<
std
::
endl
;
std
::
cout
<<
"input_permute: "
<<
input_permute
<<
std
::
endl
;
std
::
cout
<<
"output_permute: "
<<
output_permute
<<
std
::
endl
;
std
::
cout
<<
"output_permute: "
<<
output_permute
<<
std
::
endl
;
...
@@ -354,17 +357,17 @@ int run(int argc, char* argv[])
...
@@ -354,17 +357,17 @@ int run(int argc, char* argv[])
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// Q layout [G0, M, G1, K]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// Q layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1, M, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1, M, K]
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G
1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G
2
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_strides
=
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_strides
=
input_permute
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G
1
*
K
,
K
,
G
1
*
K
,
1
}
// K layout [G0, N, G
1
, K]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G
2
*
K
,
K
,
G
2
*
K
,
1
}
// K layout [G0, N, G
2
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G
1
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G
1
, N, K]
:
std
::
vector
<
ck
::
index_t
>
{
G
2
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G
2
, N, K]
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G
1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G
2
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_strides
=
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_strides
=
input_permute
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G
1
*
O
,
O
,
1
,
G
1
*
O
}
// V layout [G0, N, G
1
, O]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G
2
*
O
,
O
,
1
,
G
2
*
O
}
// V layout [G0, N, G
2
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G
1
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G
1
, N, O]
:
std
::
vector
<
ck
::
index_t
>
{
G
2
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G
2
, N, O]
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_strides
=
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_strides
=
...
@@ -377,6 +380,18 @@ int run(int argc, char* argv[])
...
@@ -377,6 +380,18 @@ int run(int argc, char* argv[])
input_permute
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// KGrad layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// KGrad layout [G0, G1, N, K]
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// VGrad layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// VGrad layout [G0, G1, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si) / exp(log(sum(exp() + ...)))
...
@@ -394,6 +409,8 @@ int run(int argc, char* argv[])
...
@@ -394,6 +409,8 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
InputDataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
DDataType
>
d_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
DDataType
>
d_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks
(
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns
(
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
);
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"k_gs_ns_ks: "
<<
k_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"k_gs_ns_ks: "
<<
k_gs_ns_ks
.
mDesc
<<
std
::
endl
;
...
@@ -402,6 +419,8 @@ int run(int argc, char* argv[])
...
@@ -402,6 +419,8 @@ int run(int argc, char* argv[])
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d_gs_ms_os: "
<<
d_gs_ms
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d_gs_ms_os: "
<<
d_gs_ms
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"kgrad_gs_ns_ks: "
<<
kgrad_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"vgrad_gs_os_ns: "
<<
vgrad_gs_os_ns
.
mDesc
<<
std
::
endl
;
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
0
});
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
0
});
switch
(
init_method
)
switch
(
init_method
)
...
@@ -478,10 +497,20 @@ int run(int argc, char* argv[])
...
@@ -478,10 +497,20 @@ int run(int argc, char* argv[])
q_gs_ms_ks
.
ForEach
(
q_gs_ms_ks
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
[
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_gs_ns_ks
.
ForEach
(
k_g_n_k
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
[
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
const
size_t
&
g0
=
idx
[
0
]
/
G1
;
v_gs_os_ns
.
ForEach
(
const
size_t
&
g1
=
idx
[
0
]
%
G1
;
[
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
const
size_t
&
g2
=
g1
/
(
G1
/
G2
);
self
(
idx
)
=
k_gs_ns_ks
(
g0
,
g2
,
idx
[
1
],
idx
[
2
]);
});
v_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1
;
const
size_t
&
g1
=
idx
[
0
]
%
G1
;
const
size_t
&
g2
=
g1
/
(
G1
/
G2
);
self
(
idx
)
=
v_gs_os_ns
(
g0
,
g2
,
idx
[
2
],
idx
[
1
]);
});
// qkv gradients have the same descriptor as with qkv
// qkv gradients have the same descriptor as with qkv
DeviceMem
q_device_buf
(
sizeof
(
InputDataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
q_device_buf
(
sizeof
(
InputDataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
...
@@ -491,8 +520,8 @@ int run(int argc, char* argv[])
...
@@ -491,8 +520,8 @@ int run(int argc, char* argv[])
DeviceMem
y_device_buf
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
y_device_buf
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
lse_device_buf
(
sizeof
(
LSEDataType
)
*
lse_gs_ms
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
lse_device_buf
(
sizeof
(
LSEDataType
)
*
lse_gs_ms
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
qgrad_device_buf
(
sizeof
(
OutputDataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
qgrad_device_buf
(
sizeof
(
OutputDataType
)
*
q_gs_ms_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
kgrad_device_buf
(
sizeof
(
OutputDataType
)
*
k_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
kgrad_device_buf
(
sizeof
(
OutputDataType
)
*
k
grad
_gs_ns_ks
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
vgrad_device_buf
(
sizeof
(
OutputDataType
)
*
v_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
vgrad_device_buf
(
sizeof
(
OutputDataType
)
*
v
grad
_gs_os_ns
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
ygrad_device_buf
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
ygrad_device_buf
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d_device_buf
(
sizeof
(
DDataType
)
*
d_gs_ms
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d_device_buf
(
sizeof
(
DDataType
)
*
d_gs_ms
.
mDesc
.
GetElementSpaceSize
());
...
@@ -533,6 +562,10 @@ int run(int argc, char* argv[])
...
@@ -533,6 +562,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths
,
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
lse_gs_ms_lengths
,
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
,
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
,
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...
@@ -581,6 +614,10 @@ int run(int argc, char* argv[])
...
@@ -581,6 +614,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths
,
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
lse_gs_ms_lengths
,
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
,
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
,
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...
@@ -763,12 +800,16 @@ int run(int argc, char* argv[])
...
@@ -763,12 +800,16 @@ int run(int argc, char* argv[])
#endif
#endif
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_host_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_host_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
kgrad_gs_ns_ks_lengths
,
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
kgrad_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
);
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_device_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_device_result
(
q_gs_ms_ks_lengths
,
q_gs_ms_ks_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
k_gs_ns_ks_lengths
,
k_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
kgrad_gs_ns_ks_lengths
,
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
v_gs_os_ns_lengths
,
v_gs_os_ns_strides
);
kgrad_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
);
qgrad_device_buf
.
FromDevice
(
qgrad_gs_ms_ks_device_result
.
mData
.
data
());
qgrad_device_buf
.
FromDevice
(
qgrad_gs_ms_ks_device_result
.
mData
.
data
());
kgrad_device_buf
.
FromDevice
(
kgrad_gs_ns_ks_device_result
.
mData
.
data
());
kgrad_device_buf
.
FromDevice
(
kgrad_gs_ns_ks_device_result
.
mData
.
data
());
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp
View file @
104aeabc
...
@@ -24,7 +24,7 @@ Kernel outputs:
...
@@ -24,7 +24,7 @@ Kernel outputs:
*/
*/
#define USING_MASK 0
#define USING_MASK 0
#define DIM
32
// DIM should be a multiple of 8.
#define DIM
128
// DIM should be a multiple of 8.
#include <iostream>
#include <iostream>
#include <numeric>
#include <numeric>
...
@@ -273,6 +273,7 @@ int run(int argc, char* argv[])
...
@@ -273,6 +273,7 @@ int run(int argc, char* argv[])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
float
alpha
=
1.
f
/
std
::
sqrt
(
DIM
);
float
alpha
=
1.
f
/
std
::
sqrt
(
DIM
);
float
p_drop
=
0.0
;
float
p_drop
=
0.0
;
int
h_ratio
=
1
;
// G1 / G2
bool
input_permute
=
true
;
bool
input_permute
=
true
;
bool
output_permute
=
true
;
bool
output_permute
=
true
;
...
@@ -290,25 +291,26 @@ int run(int argc, char* argv[])
...
@@ -290,25 +291,26 @@ int run(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
}
else
if
(
argc
==
7
)
else
if
(
argc
==
8
)
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
p_drop
=
std
::
stof
(
argv
[
4
]);
p_drop
=
std
::
stof
(
argv
[
4
]);
h_ratio
=
std
::
stof
(
argv
[
5
]);
input_permute
=
std
::
stoi
(
argv
[
5
]);
input_permute
=
std
::
stoi
(
argv
[
6
]);
output_permute
=
std
::
stoi
(
argv
[
6
]);
output_permute
=
std
::
stoi
(
argv
[
7
]);
}
}
else
else
{
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4
to 11: M, N, K, O, G0, G1
\n
"
);
printf
(
"arg4
: p_drop
\n
"
);
printf
(
"arg
10: scale (alpha)
\n
"
);
printf
(
"arg
5: h_ratio
\n
"
);
printf
(
"arg
11
to
12
: input / output permute
\n
"
);
printf
(
"arg
6
to
7
: input / output permute
\n
"
);
exit
(
0
);
exit
(
0
);
}
}
...
@@ -376,24 +378,25 @@ int run(int argc, char* argv[])
...
@@ -376,24 +378,25 @@ int run(int argc, char* argv[])
int
K
=
DIM
;
int
K
=
DIM
;
int
O
=
DIM
;
int
O
=
DIM
;
int
G0
=
rand
()
%
4
+
1
;
int
G0
=
rand
()
%
4
+
1
;
int
G1
=
rand
()
%
4
+
1
;
int
G2
=
rand
()
%
4
+
1
;
int
G1
=
G2
*
h_ratio
;
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_strides
=
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_strides
=
input_permute
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// Q layout [G0, M, G1, K]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// Q layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1, M, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// Q layout [G0, G1, M, K]
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G
1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_lengths
{
G0
,
G
2
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_strides
=
std
::
vector
<
ck
::
index_t
>
k_gs_ns_ks_strides
=
input_permute
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G
1
*
K
,
K
,
G
1
*
K
,
1
}
// K layout [G0, N, G
1
, K]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G
2
*
K
,
K
,
G
2
*
K
,
1
}
// K layout [G0, N, G
2
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G
1
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G
1
, N, K]
:
std
::
vector
<
ck
::
index_t
>
{
G
2
*
N
*
K
,
N
*
K
,
K
,
1
};
// K layout [G0, G
2
, N, K]
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G
1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_lengths
{
G0
,
G
2
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_strides
=
std
::
vector
<
ck
::
index_t
>
v_gs_os_ns_strides
=
input_permute
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G
1
*
O
,
O
,
1
,
G
1
*
O
}
// V layout [G0, N, G
1
, O]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G
2
*
O
,
O
,
1
,
G
2
*
O
}
// V layout [G0, N, G
2
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G
1
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G
1
, N, O]
:
std
::
vector
<
ck
::
index_t
>
{
G
2
*
N
*
O
,
N
*
O
,
1
,
O
};
// V layout [G0, G
2
, N, O]
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_strides
=
std
::
vector
<
ck
::
index_t
>
y_gs_ms_os_strides
=
...
@@ -406,6 +409,17 @@ int run(int argc, char* argv[])
...
@@ -406,6 +409,17 @@ int run(int argc, char* argv[])
input_permute
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// Z layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// Z layout [G0, G1, M, N]
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
kgrad_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// KGrad layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// KGrad layout [G0, G1, N, K]
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
vgrad_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// VGrad layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// VGrad layout [G0, G1, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
// pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si) / exp(log(sum(exp() + ...)))
...
@@ -427,6 +441,10 @@ int run(int argc, char* argv[])
...
@@ -427,6 +441,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_strides
,
y_gs_ms_os_strides
,
lse_gs_ms_lengths
,
lse_gs_ms_lengths
,
lse_gs_ms_strides
,
lse_gs_ms_strides
,
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
,
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
,
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...
@@ -451,6 +469,8 @@ int run(int argc, char* argv[])
...
@@ -451,6 +469,8 @@ int run(int argc, char* argv[])
Tensor
<
InputDataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
InputDataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
DDataType
>
d_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
DDataType
>
d_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks
(
kgrad_gs_ns_ks_lengths
,
kgrad_gs_ns_ks_strides
);
Tensor
<
OutputDataType
>
vgrad_gs_os_ns
(
vgrad_gs_os_ns_lengths
,
vgrad_gs_os_ns_strides
);
if
(
i
<
4
)
if
(
i
<
4
)
{
{
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
...
@@ -460,6 +480,8 @@ int run(int argc, char* argv[])
...
@@ -460,6 +480,8 @@ int run(int argc, char* argv[])
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d_gs_ms_os: "
<<
d_gs_ms
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d_gs_ms_os: "
<<
d_gs_ms
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"kgrad_gs_ns_ks: "
<<
kgrad_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"vgrad_gs_os_ns: "
<<
vgrad_gs_os_ns
.
mDesc
<<
std
::
endl
;
}
}
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
0
});
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
0
});
switch
(
init_method
)
switch
(
init_method
)
...
@@ -538,11 +560,19 @@ int run(int argc, char* argv[])
...
@@ -538,11 +560,19 @@ int run(int argc, char* argv[])
q_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
q_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
});
k_gs_ns_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
const
size_t
&
g0
=
idx
[
0
]
/
G1
;
const
size_t
&
g1
=
idx
[
0
]
%
G1
;
const
size_t
&
g2
=
g1
/
h_ratio
;
self
(
idx
)
=
k_gs_ns_ks
(
g0
,
g2
,
idx
[
1
],
idx
[
2
]);
});
});
v_gs_os_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
const
size_t
&
g0
=
idx
[
0
]
/
G1
;
const
size_t
&
g1
=
idx
[
0
]
%
G1
;
const
size_t
&
g2
=
g1
/
h_ratio
;
self
(
idx
)
=
v_gs_os_ns
(
g0
,
g2
,
idx
[
2
],
idx
[
1
]);
});
});
q_g_m_ks
.
push_back
(
q_g_m_k
);
q_g_m_ks
.
push_back
(
q_g_m_k
);
...
@@ -562,6 +592,8 @@ int run(int argc, char* argv[])
...
@@ -562,6 +592,8 @@ int run(int argc, char* argv[])
z_tensors
.
push_back
(
z_gs_ms_ns
);
z_tensors
.
push_back
(
z_gs_ms_ns
);
lse_tensors
.
push_back
(
lse_gs_ms
);
lse_tensors
.
push_back
(
lse_gs_ms
);
ygrad_tensors
.
push_back
(
ygrad_gs_ms_os
);
ygrad_tensors
.
push_back
(
ygrad_gs_ms_os
);
kgrad_tensors
.
push_back
(
kgrad_gs_ns_ks
);
vgrad_tensors
.
push_back
(
vgrad_gs_os_ns
);
q_tensors_device
.
emplace_back
(
q_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
InputDataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
InputDataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
k_tensors_device
.
emplace_back
(
k_tensors_device
.
emplace_back
(
...
@@ -578,10 +610,10 @@ int run(int argc, char* argv[])
...
@@ -578,10 +610,10 @@ int run(int argc, char* argv[])
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DDataType
)
*
d_gs_ms
.
GetElementSpaceSize
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DDataType
)
*
d_gs_ms
.
GetElementSpaceSize
()));
qgrad_tensors_device
.
emplace_back
(
qgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
kgrad_tensors_device
.
emplace_back
(
kgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
k_gs_ns_ks
.
GetElementSpaceSize
()));
sizeof
(
OutputDataType
)
*
k
grad
_gs_ns_ks
.
GetElementSpaceSize
()));
vgrad_tensors_device
.
emplace_back
(
vgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
OutputDataType
)
*
v_gs_os_ns
.
GetElementSpaceSize
()));
sizeof
(
OutputDataType
)
*
v
grad
_gs_os_ns
.
GetElementSpaceSize
()));
ygrad_tensors_device
.
emplace_back
(
ygrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
GetElementSpaceSize
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
InputDataType
)
*
y_gs_ms_os
.
GetElementSpaceSize
()));
q_tensors_device
.
back
()
->
ToDevice
(
q_gs_ms_ks
.
data
());
q_tensors_device
.
back
()
->
ToDevice
(
q_gs_ms_ks
.
data
());
...
@@ -625,6 +657,7 @@ int run(int argc, char* argv[])
...
@@ -625,6 +657,7 @@ int run(int argc, char* argv[])
QKVElementOp
{},
QKVElementOp
{},
YElementOp
{},
YElementOp
{},
p_drop
,
p_drop
,
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
DeviceMem
problem_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
DeviceMem
problem_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
...
@@ -674,6 +707,7 @@ int run(int argc, char* argv[])
...
@@ -674,6 +707,7 @@ int run(int argc, char* argv[])
QKVElementOp
{},
QKVElementOp
{},
YElementOp
{},
YElementOp
{},
p_drop
,
p_drop
,
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
DeviceMem
problem_desc_workspace_verify
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
DeviceMem
problem_desc_workspace_verify
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
gemm
.
SetWorkSpacePointer
(
&
argument
,
problem_desc_workspace_verify
.
GetDeviceBuffer
());
gemm
.
SetWorkSpacePointer
(
&
argument
,
problem_desc_workspace_verify
.
GetDeviceBuffer
());
...
@@ -687,7 +721,7 @@ int run(int argc, char* argv[])
...
@@ -687,7 +721,7 @@ int run(int argc, char* argv[])
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
{
int
G1
=
v
_tensors
[
i
].
GetLengths
()[
1
];
int
G1
=
q
_tensors
[
i
].
GetLengths
()[
1
];
// copy z matirx data form device
// copy z matirx data form device
z_tensors_device
[
i
]
->
FromDevice
(
z_tensors
[
i
].
mData
.
data
());
z_tensors_device
[
i
]
->
FromDevice
(
z_tensors
[
i
].
mData
.
data
());
z_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
z_tensors
[
i
].
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
...
@@ -724,8 +758,8 @@ int run(int argc, char* argv[])
...
@@ -724,8 +758,8 @@ int run(int argc, char* argv[])
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
{
int
G0
=
v
_tensors
[
i
].
GetLengths
()[
0
];
int
G0
=
q
_tensors
[
i
].
GetLengths
()[
0
];
int
G1
=
v
_tensors
[
i
].
GetLengths
()[
1
];
int
G1
=
q
_tensors
[
i
].
GetLengths
()[
1
];
int
O
=
v_tensors
[
i
].
GetLengths
()[
2
];
int
O
=
v_tensors
[
i
].
GetLengths
()[
2
];
int
N
=
v_tensors
[
i
].
GetLengths
()[
3
];
int
N
=
v_tensors
[
i
].
GetLengths
()[
3
];
int
M
=
q_tensors
[
i
].
GetLengths
()[
2
];
int
M
=
q_tensors
[
i
].
GetLengths
()[
2
];
...
@@ -783,17 +817,17 @@ int run(int argc, char* argv[])
...
@@ -783,17 +817,17 @@ int run(int argc, char* argv[])
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_host_result
(
q_tensors
[
i
].
GetLengths
(),
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_host_result
(
q_tensors
[
i
].
GetLengths
(),
q_tensors
[
i
].
GetStrides
());
q_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
k_tensors
[
i
].
GetLengths
(),
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_host_result
(
k
grad
_tensors
[
i
].
GetLengths
(),
k_tensors
[
i
].
GetStrides
());
k
grad
_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
v_tensors
[
i
].
GetLengths
(),
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_host_result
(
v
grad
_tensors
[
i
].
GetLengths
(),
v_tensors
[
i
].
GetStrides
());
v
grad
_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_device_result
(
q_tensors
[
i
].
GetLengths
(),
Tensor
<
OutputDataType
>
qgrad_gs_ms_ks_device_result
(
q_tensors
[
i
].
GetLengths
(),
q_tensors
[
i
].
GetStrides
());
q_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
k_tensors
[
i
].
GetLengths
(),
Tensor
<
OutputDataType
>
kgrad_gs_ns_ks_device_result
(
k
grad
_tensors
[
i
].
GetLengths
(),
k_tensors
[
i
].
GetStrides
());
k
grad
_tensors
[
i
].
GetStrides
());
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
v_tensors
[
i
].
GetLengths
(),
Tensor
<
OutputDataType
>
vgrad_gs_os_ns_device_result
(
v
grad
_tensors
[
i
].
GetLengths
(),
v_tensors
[
i
].
GetStrides
());
v
grad
_tensors
[
i
].
GetStrides
());
qgrad_tensors_device
[
i
]
->
FromDevice
(
qgrad_gs_ms_ks_device_result
.
data
());
qgrad_tensors_device
[
i
]
->
FromDevice
(
qgrad_gs_ms_ks_device_result
.
data
());
kgrad_tensors_device
[
i
]
->
FromDevice
(
kgrad_gs_ns_ks_device_result
.
data
());
kgrad_tensors_device
[
i
]
->
FromDevice
(
kgrad_gs_ns_ks_device_result
.
data
());
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
104aeabc
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
104aeabc
...
@@ -132,14 +132,17 @@ __global__ void
...
@@ -132,14 +132,17 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
CElementwiseOperation
c_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
bgrad_grid_desc_bk0_n_bk1
,
const
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
const
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
B1GridDesc_BK0_N_BK1
b1grad_grid_desc_bk0_n_bk1
,
const
LSEGridDescriptor_M
lse_grid_desc_m
,
const
LSEGridDescriptor_M
lse_grid_desc_m
,
const
YGradGridDesc_M0_O_M1
ygrad_grid_desc_m0_o_m1
,
const
YGradGridDesc_M0_O_M1
ygrad_grid_desc_m0_o_m1
,
const
Block2CTileMap
block_2_ctile_map
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
const
index_t
batch_count
,
const
index_t
h_ratio
,
const
index_t
nblock
,
const
index_t
nblock
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
C0MatrixMask
c0_matrix_mask
,
const
C0MatrixMask
c0_matrix_mask
,
...
@@ -155,21 +158,26 @@ __global__ void
...
@@ -155,21 +158,26 @@ __global__ void
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
index_t
gkv_idx
=
__builtin_amdgcn_readfirstlane
(
g_idx
/
h_ratio
);
// NOTE: assumes QKVY has the same layout as dQ/dK/dV/dY therefore being able to reuse batch
// NOTE: assumes QKVY has the same layout as dQ/dK/dV/dY therefore being able to reuse batch
// offsets
// offsets
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetABasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetABasePtr
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetBBasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetBBasePtr
(
g
kv
_idx
)));
const
long_index_t
z_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
z_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetZBasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetZBasePtr
(
g_idx
)));
const
long_index_t
b1_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
b1_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB1BasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB1BasePtr
(
g
kv
_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetCBasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetCBasePtr
(
g_idx
)));
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetLSEBasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetLSEBasePtr
(
g_idx
)));
const
long_index_t
bgrad_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetBGradBasePtr
(
g_idx
)));
const
long_index_t
b1grad_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB1GradBasePtr
(
g_idx
)));
ck
::
philox
ph
(
seed
,
0
,
offset
);
ck
::
philox
ph
(
seed
,
0
,
offset
);
ZDataType
*
z_matrix_ptr
=
(
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
);
ZDataType
*
z_matrix_ptr
=
(
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
);
...
@@ -206,9 +214,9 @@ __global__ void
...
@@ -206,9 +214,9 @@ __global__ void
p_d_grid
+
lse_batch_offset
,
p_d_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_kgrad_grid
+
b
grad
_batch_offset
,
tmp_p_d0grad_grid
,
tmp_p_d0grad_grid
,
p_vgrad_grid
+
b1_batch_offset
,
p_vgrad_grid
+
b1
grad
_batch_offset
,
p_shared
,
p_shared
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
...
@@ -217,9 +225,11 @@ __global__ void
...
@@ -217,9 +225,11 @@ __global__ void
c_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
bgrad_grid_desc_bk0_n_bk1
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
b1grad_grid_desc_bk0_n_bk1
,
lse_grid_desc_m
,
lse_grid_desc_m
,
ygrad_grid_desc_m0_o_m1
,
ygrad_grid_desc_m0_o_m1
,
block_2_ctile_map
,
block_2_ctile_map
,
...
@@ -243,9 +253,9 @@ __global__ void
...
@@ -243,9 +253,9 @@ __global__ void
p_d_grid
+
lse_batch_offset
,
p_d_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_kgrad_grid
+
b
grad
_batch_offset
,
tmp_p_d0grad_grid
,
tmp_p_d0grad_grid
,
p_vgrad_grid
+
b1_batch_offset
,
p_vgrad_grid
+
b1
grad
_batch_offset
,
p_shared
,
p_shared
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
...
@@ -254,9 +264,11 @@ __global__ void
...
@@ -254,9 +264,11 @@ __global__ void
c_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
bgrad_grid_desc_bk0_n_bk1
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
b1grad_grid_desc_bk0_n_bk1
,
lse_grid_desc_m
,
lse_grid_desc_m
,
ygrad_grid_desc_m0_o_m1
,
ygrad_grid_desc_m0_o_m1
,
block_2_ctile_map
,
block_2_ctile_map
,
...
@@ -287,13 +299,16 @@ __global__ void
...
@@ -287,13 +299,16 @@ __global__ void
ignore
=
c_element_op
;
ignore
=
c_element_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
bgrad_grid_desc_bk0_n_bk1
;
ignore
=
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
ignore
=
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
ignore
=
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
;
ignore
=
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
;
ignore
=
b1_grid_desc_bk0_n_bk1
;
ignore
=
b1_grid_desc_bk0_n_bk1
;
ignore
=
b1grad_grid_desc_bk0_n_bk1
;
ignore
=
lse_grid_desc_m
;
ignore
=
lse_grid_desc_m
;
ignore
=
ygrad_grid_desc_m0_o_m1
;
ignore
=
ygrad_grid_desc_m0_o_m1
;
ignore
=
block_2_ctile_map
;
ignore
=
block_2_ctile_map
;
ignore
=
batch_count
;
ignore
=
batch_count
;
ignore
=
h_ratio
;
ignore
=
nblock
;
ignore
=
nblock
;
ignore
=
compute_base_ptr_of_batch
;
ignore
=
compute_base_ptr_of_batch
;
ignore
=
c0_matrix_mask
;
ignore
=
c0_matrix_mask
;
...
@@ -704,6 +719,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -704,6 +719,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
ZGridDesc_G_M_N
&
z_grid_desc_g_m_n
,
const
ZGridDesc_G_M_N
&
z_grid_desc_g_m_n
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
,
const
BGridDesc_G_N_K
&
bgrad_grid_desc_g_n_k
,
const
B1GridDesc_G_N_K
&
b1grad_grid_desc_g_n_k
,
index_t
BatchStrideLSE
)
index_t
BatchStrideLSE
)
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
...
@@ -711,6 +728,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -711,6 +728,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
z_grid_desc_g_m_n_
(
z_grid_desc_g_m_n
),
z_grid_desc_g_m_n_
(
z_grid_desc_g_m_n
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
),
bgrad_grid_desc_g_n_k_
(
bgrad_grid_desc_g_n_k
),
b1grad_grid_desc_g_n_k_
(
b1grad_grid_desc_g_n_k
),
BatchStrideLSE_
(
BatchStrideLSE
)
BatchStrideLSE_
(
BatchStrideLSE
)
{
{
}
}
...
@@ -729,6 +748,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -729,6 +748,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{
{
return
d0_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
return
d0_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
}
__host__
__device__
constexpr
long_index_t
GetZBasePtr
(
index_t
g_idx
)
const
__host__
__device__
constexpr
long_index_t
GetZBasePtr
(
index_t
g_idx
)
const
{
{
return
z_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
return
z_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
...
@@ -749,6 +769,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -749,6 +769,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideLSE_
);
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideLSE_
);
}
}
__host__
__device__
constexpr
long_index_t
GetBGradBasePtr
(
index_t
g_idx
)
const
{
return
bgrad_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetB1GradBasePtr
(
index_t
g_idx
)
const
{
return
b1grad_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
private:
private:
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
...
@@ -756,6 +786,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -756,6 +786,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
BGridDesc_G_N_K
bgrad_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1grad_grid_desc_g_n_k_
;
index_t
BatchStrideLSE_
;
index_t
BatchStrideLSE_
;
};
};
...
@@ -874,6 +906,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -874,6 +906,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_lengths
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_strides
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
const
std
::
vector
<
ck
::
index_t
>&
...
@@ -904,9 +940,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -904,9 +940,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
)},
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
)},
b_grid_desc_bk0_n_bk1_
{
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
)},
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
b_gs_ns_ks_lengths
,
b_gs_ns_ks_strides
)},
bgrad_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
bgrad_gs_ns_ks_lengths
,
bgrad_gs_ns_ks_strides
)},
z_grid_desc_m_n_
{
MakeZGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
z_grid_desc_m_n_
{
MakeZGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
b1_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeVGridDescriptor_O0_N_O1
(
b1_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeVGridDescriptor_O0_N_O1
(
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
b1grad_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeVGridDescriptor_O0_N_O1
(
b1grad_gs_gemm1ns_gemm1ks_lengths
,
b1grad_gs_gemm1ns_gemm1ks_strides
)},
y_grid_desc_m_o_
{
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
y_grid_desc_m_o_
{
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
c_gs_ms_gemm1ns_strides
)},
d_y_grid_desc_m_o_
{
DTransform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
d_y_grid_desc_m_o_
{
DTransform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
...
@@ -927,6 +967,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -927,6 +967,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
c_gs_ms_gemm1ns_strides
)},
c_gs_ms_gemm1ns_strides
)},
z_grid_desc_g_m_n_
{
z_grid_desc_g_m_n_
{
Transform
::
MakeC0GridDescriptor_G_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
Transform
::
MakeC0GridDescriptor_G_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
bgrad_grid_desc_g_n_k_
{
Transform
::
MakeB0GridDescriptor_G_N_K
(
bgrad_gs_ns_ks_lengths
,
bgrad_gs_ns_ks_strides
)},
b1grad_grid_desc_g_n_k_
{
Transform
::
MakeB1GridDescriptor_G_N_K
(
b1grad_gs_gemm1ns_gemm1ks_lengths
,
b1grad_gs_gemm1ns_gemm1ks_strides
)},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
k_grid_desc_n_k_
)},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
k_grid_desc_n_k_
)},
d_block_2_ctile_map_
{
d_block_2_ctile_map_
{
GridwiseYDotYGrad
::
MakeDefaultBlock2CTileMap
(
d_y_grid_desc_m_o_
)},
GridwiseYDotYGrad
::
MakeDefaultBlock2CTileMap
(
d_y_grid_desc_m_o_
)},
...
@@ -950,6 +994,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -950,6 +994,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
c_mz_gemm1nz_strides_
{
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
-
1
],
c_mz_gemm1nz_strides_
{
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
-
1
],
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
batch_count_
{
c_grid_desc_g_m_n_
.
GetLength
(
I0
)},
batch_count_
{
c_grid_desc_g_m_n_
.
GetLength
(
I0
)},
h_ratio_
{
c_grid_desc_g_m_n_
.
GetLength
(
I0
)
/
b_grid_desc_g_n_k_
.
GetLength
(
I0
)},
p_drop_
{
p_drop
}
p_drop_
{
p_drop
}
{
{
// TODO: implement bias addition
// TODO: implement bias addition
...
@@ -979,6 +1024,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -979,6 +1024,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
z_grid_desc_g_m_n_
,
z_grid_desc_g_m_n_
,
b1_grid_desc_g_n_k_
,
b1_grid_desc_g_n_k_
,
c_grid_desc_g_m_n_
,
c_grid_desc_g_m_n_
,
bgrad_grid_desc_g_n_k_
,
b1grad_grid_desc_g_n_k_
,
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
()));
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
()));
seed_
=
std
::
get
<
0
>
(
seeds
);
seed_
=
std
::
get
<
0
>
(
seeds
);
...
@@ -1005,7 +1052,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1005,7 +1052,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
<<
b_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
<<
b_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// b_grid_desc_g_n_k_.Print();
// b_grid_desc_g_n_k_.Print();
std
::
cout
<<
"b1_grid_desc_g_
o_n
_: "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"b1_grid_desc_g_
n_k
_: "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// b1_grid_desc_g_n_k_.Print();
// b1_grid_desc_g_n_k_.Print();
...
@@ -1018,6 +1065,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1018,6 +1065,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
std
::
cout
<<
"ygrad_grid_desc_m0_o_m1_: "
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I0
)
std
::
cout
<<
"ygrad_grid_desc_m0_o_m1_: "
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I0
)
<<
", "
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I1
)
<<
", "
<<
", "
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I1
)
<<
", "
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I2
)
<<
'\n'
;
<<
ygrad_grid_desc_m0_o_m1_
.
GetLength
(
I2
)
<<
'\n'
;
std
::
cout
<<
"d0_grid_desc_g_m_n_: "
<<
d0_grid_desc_g_m_n_
.
GetLength
(
I0
)
<<
", "
<<
d0_grid_desc_g_m_n_
.
GetLength
(
I1
)
<<
", "
<<
d0_grid_desc_g_m_n_
.
GetLength
(
I2
)
<<
'\n'
;
std
::
cout
<<
"bgrad_grid_desc_g_n_k_: "
<<
bgrad_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
bgrad_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
bgrad_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// bgrad_grid_desc_g_n_k_.Print();
std
::
cout
<<
"b1grad_grid_desc_g_n_k_: "
<<
b1grad_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b1grad_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b1grad_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// b1grad_grid_desc_g_n_k_.Print();
}
}
// pointers
// pointers
...
@@ -1038,9 +1096,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1038,9 +1096,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// tensor descriptor
// tensor descriptor
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
bgrad_grid_desc_bk0_n_bk1_
;
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3_
;
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
B1GridDesc_BK0_N_BK1
b1grad_grid_desc_bk0_n_bk1_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
DYGridDesc_M_O
d_y_grid_desc_m_o_
;
DYGridDesc_M_O
d_y_grid_desc_m_o_
;
LSEGridDesc_M
lse_grid_desc_m_
;
LSEGridDesc_M
lse_grid_desc_m_
;
...
@@ -1055,6 +1115,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1055,6 +1115,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
BGridDesc_G_N_K
bgrad_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1grad_grid_desc_g_n_k_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
;
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
;
...
@@ -1083,6 +1145,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1083,6 +1145,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
std
::
vector
<
index_t
>
c_mz_gemm1nz_strides_
;
std
::
vector
<
index_t
>
c_mz_gemm1nz_strides_
;
index_t
batch_count_
;
index_t
batch_count_
;
index_t
h_ratio_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
float
p_drop_
;
float
p_drop_
;
...
@@ -1208,13 +1271,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1208,13 +1271,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
arg
.
c_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
bgrad_grid_desc_bk0_n_bk1_
,
arg
.
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg
.
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
b1grad_grid_desc_bk0_n_bk1_
,
arg
.
lse_grid_desc_m_
,
arg
.
lse_grid_desc_m_
,
arg
.
ygrad_grid_desc_m0_o_m1_
,
arg
.
ygrad_grid_desc_m0_o_m1_
,
arg
.
block_2_ctile_map_
,
arg
.
block_2_ctile_map_
,
arg
.
batch_count_
,
arg
.
batch_count_
,
arg
.
h_ratio_
,
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
k_grid_desc_n_k_
),
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
k_grid_desc_n_k_
),
arg
.
compute_base_ptr_of_batch_
,
arg
.
compute_base_ptr_of_batch_
,
arg
.
c0_matrix_mask_
,
arg
.
c0_matrix_mask_
,
...
@@ -1280,13 +1346,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1280,13 +1346,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// Check if C permute dimension matches GEMM + GEMM shape
// Check if C permute dimension matches GEMM + GEMM shape
const
index_t
c_g
=
arg
.
c_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
c_g
=
arg
.
c_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
b_g
=
arg
.
b_grid_desc_g_n_k_
.
GetLength
(
I0
);
const
index_t
c_m
=
arg
.
y_grid_desc_m_o_
.
GetLength
(
I0
);
const
index_t
c_m
=
arg
.
y_grid_desc_m_o_
.
GetLength
(
I0
);
const
index_t
c_gemm1n
=
arg
.
y_grid_desc_m_o_
.
GetLength
(
I1
);
const
index_t
c_gemm1n
=
arg
.
y_grid_desc_m_o_
.
GetLength
(
I1
);
const
index_t
a_m
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
a_m
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
b1_gemm1n
=
const
index_t
b1_gemm1n
=
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
))
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
&&
c_g
%
b_g
==
0
&&
b_g
<=
c_g
))
{
{
return
false
;
return
false
;
}
}
...
@@ -1380,6 +1448,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1380,6 +1448,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_lengths
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_strides
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
const
std
::
vector
<
ck
::
index_t
>&
...
@@ -1420,6 +1492,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1420,6 +1492,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
lse_gs_ms_lengths
,
lse_gs_ms_lengths
,
bgrad_gs_ns_ks_lengths
,
bgrad_gs_ns_ks_strides
,
b1grad_gs_gemm1ns_gemm1ks_lengths
,
b1grad_gs_gemm1ns_gemm1ks_strides
,
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
,
acc0_bias_gs_ms_ns_strides
,
acc1_bias_gs_ms_gemm1ns_lengths
,
// acc1_bias_gs_ms_os_lengths
acc1_bias_gs_ms_gemm1ns_lengths
,
// acc1_bias_gs_ms_os_lengths
...
@@ -1464,6 +1540,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1464,6 +1540,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
bgrad_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_lengths
,
const
std
::
vector
<
index_t
>&
b1grad_gs_gemm1ns_gemm1ks_strides
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
,
const
std
::
vector
<
ck
::
index_t
>&
const
std
::
vector
<
ck
::
index_t
>&
...
@@ -1505,6 +1585,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1505,6 +1585,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
lse_gs_ms_lengths
,
lse_gs_ms_lengths
,
bgrad_gs_ns_ks_lengths
,
bgrad_gs_ns_ks_strides
,
b1grad_gs_gemm1ns_gemm1ks_lengths
,
b1grad_gs_gemm1ns_gemm1ks_strides
,
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
,
acc0_bias_gs_ms_ns_strides
,
acc1_bias_gs_ms_gemm1ns_lengths
,
acc1_bias_gs_ms_gemm1ns_lengths
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
104aeabc
...
@@ -103,6 +103,7 @@ __global__ void
...
@@ -103,6 +103,7 @@ __global__ void
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_light_v1
(
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_light_v1
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
group_kernel_args
,
const
void
CK_CONSTANT_ADDRESS_SPACE
*
group_kernel_args
,
const
index_t
group_count
,
const
index_t
group_count
,
const
index_t
h_ratio
,
const
AElementwiseOperation
a_element_op
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
BElementwiseOperation
b_element_op
,
const
AccElementwiseOperation
acc_element_op
,
const
AccElementwiseOperation
acc_element_op
,
...
@@ -141,19 +142,26 @@ __global__ void
...
@@ -141,19 +142,26 @@ __global__ void
const
index_t
num_blocks_per_batch
=
arg_ptr
[
group_id
].
num_blocks_per_batch_
;
const
index_t
num_blocks_per_batch
=
arg_ptr
[
group_id
].
num_blocks_per_batch_
;
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
(
block_id
-
arg_ptr
[
group_id
].
block_start_
)
/
(
Deterministic
?
1
:
num_blocks_per_batch
));
(
block_id
-
arg_ptr
[
group_id
].
block_start_
)
/
(
Deterministic
?
1
:
num_blocks_per_batch
));
const
index_t
gkv_idx
=
__builtin_amdgcn_readfirstlane
(
g_idx
/
h_ratio
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetABasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetABasePtr
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetBBasePtr
(
g_idx
)));
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetBBasePtr
(
g
kv
_idx
)));
const
long_index_t
z_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
z_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetZBasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetZBasePtr
(
g_idx
)));
const
long_index_t
b1_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
const
long_index_t
b1_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetB1BasePtr
(
g_idx
)));
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetB1BasePtr
(
g
kv
_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetCBasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetCBasePtr
(
g_idx
)));
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetLSEBasePtr
(
g_idx
)));
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetLSEBasePtr
(
g_idx
)));
const
long_index_t
bgrad_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetBGradBasePtr
(
g_idx
)));
const
long_index_t
b1grad_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetB1GradBasePtr
(
g_idx
)));
const
index_t
global_thread_id
=
get_thread_global_1d_id
();
const
index_t
global_thread_id
=
get_thread_global_1d_id
();
ck
::
philox
ph
(
seed
,
global_thread_id
,
offset
);
ck
::
philox
ph
(
seed
,
global_thread_id
,
offset
);
...
@@ -168,6 +176,7 @@ __global__ void
...
@@ -168,6 +176,7 @@ __global__ void
const
long_index_t
d0_batch_offset
=
const
long_index_t
d0_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetD0BasePtr
(
g_idx
)));
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetD0BasePtr
(
g_idx
)));
if
(
arg_ptr
[
group_id
].
p_d0_grid_
!=
nullptr
)
if
(
arg_ptr
[
group_id
].
p_d0_grid_
!=
nullptr
)
tmp_p_d0_grid
=
arg_ptr
[
group_id
].
p_d0_grid_
+
d0_batch_offset
;
tmp_p_d0_grid
=
arg_ptr
[
group_id
].
p_d0_grid_
+
d0_batch_offset
;
if
(
arg_ptr
[
group_id
].
p_d0grad_grid_
)
if
(
arg_ptr
[
group_id
].
p_d0grad_grid_
)
...
@@ -187,9 +196,9 @@ __global__ void
...
@@ -187,9 +196,9 @@ __global__ void
arg_ptr
[
group_id
].
p_d_grid_
+
lse_batch_offset
,
arg_ptr
[
group_id
].
p_d_grid_
+
lse_batch_offset
,
arg_ptr
[
group_id
].
p_ygrad_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_ygrad_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_qgrad_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_qgrad_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_kgrad_grid_
+
b_batch_offset
,
arg_ptr
[
group_id
].
p_kgrad_grid_
+
b
grad
_batch_offset
,
tmp_p_d0grad_grid
,
tmp_p_d0grad_grid
,
arg_ptr
[
group_id
].
p_vgrad_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_vgrad_grid_
+
b1
grad
_batch_offset
,
p_shared
,
p_shared
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
...
@@ -198,9 +207,11 @@ __global__ void
...
@@ -198,9 +207,11 @@ __global__ void
c_element_op
,
c_element_op
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
bgrad_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg_ptr
[
group_id
].
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b1grad_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_o0_m_o1_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_o0_m_o1_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
...
@@ -225,9 +236,9 @@ __global__ void
...
@@ -225,9 +236,9 @@ __global__ void
arg_ptr
[
group_id
].
p_d_grid_
+
lse_batch_offset
,
arg_ptr
[
group_id
].
p_d_grid_
+
lse_batch_offset
,
arg_ptr
[
group_id
].
p_ygrad_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_ygrad_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_qgrad_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_qgrad_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_kgrad_grid_
+
b_batch_offset
,
arg_ptr
[
group_id
].
p_kgrad_grid_
+
b
grad
_batch_offset
,
tmp_p_d0grad_grid
,
tmp_p_d0grad_grid
,
arg_ptr
[
group_id
].
p_vgrad_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_vgrad_grid_
+
b1
grad
_batch_offset
,
p_shared
,
p_shared
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
...
@@ -236,9 +247,11 @@ __global__ void
...
@@ -236,9 +247,11 @@ __global__ void
c_element_op
,
c_element_op
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
bgrad_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg_ptr
[
group_id
].
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b1grad_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_o0_m_o1_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_o0_m_o1_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
...
@@ -253,6 +266,7 @@ __global__ void
...
@@ -253,6 +266,7 @@ __global__ void
#else
#else
ignore
=
group_kernel_args
;
ignore
=
group_kernel_args
;
ignore
=
group_count
;
ignore
=
group_count
;
ignore
=
h_ratio
;
ignore
=
a_element_op
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
b_element_op
;
ignore
=
acc_element_op
;
ignore
=
acc_element_op
;
...
@@ -366,6 +380,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -366,6 +380,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
std
::
vector
<
index_t
>
lse_gs_ms_lengths
;
std
::
vector
<
index_t
>
lse_gs_ms_lengths
;
std
::
vector
<
index_t
>
lse_gs_ms_strides
;
std
::
vector
<
index_t
>
lse_gs_ms_strides
;
std
::
vector
<
index_t
>
bgrad_gs_ns_ks_lengths
;
std
::
vector
<
index_t
>
bgrad_gs_ns_ks_strides
;
std
::
vector
<
index_t
>
b1grad_gs_gemm1ns_gemm1ks_lengths
;
std
::
vector
<
index_t
>
b1grad_gs_gemm1ns_gemm1ks_strides
;
std
::
vector
<
index_t
>
acc0_bias_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
acc0_bias_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
acc0_bias_gs_ms_ns_strides
;
std
::
vector
<
index_t
>
acc0_bias_gs_ms_ns_strides
;
...
@@ -576,7 +596,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -576,7 +596,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeC0GridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
return
Transform
::
MakeC0GridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_strides
);
}
}
...
@@ -585,7 +604,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -585,7 +604,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
MakeD0GridDescriptor_G_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
MakeD0GridDescriptor_G_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeC0GridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
return
Transform
::
MakeC0GridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_strides
);
}
}
...
@@ -625,7 +643,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -625,7 +643,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
using
D0GridDesc_G_M_N
=
decltype
(
MakeD0GridDescriptor_G_M_N
({},
{}));
using
D0GridDesc_G_M_N
=
decltype
(
MakeD0GridDescriptor_G_M_N
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeC
0
GridDescriptor_G_M_N
({},
{}));
using
KGridDesc_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_N_K
({},
{}));
using
KGridDesc_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_N_K
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeD0GridDescriptor_M_N
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeD0GridDescriptor_M_N
({},
{}));
...
@@ -660,6 +678,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -660,6 +678,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const
ZGridDesc_G_M_N
&
z_grid_desc_g_m_n
,
const
ZGridDesc_G_M_N
&
z_grid_desc_g_m_n
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
,
const
BGridDesc_G_N_K
&
bgrad_grid_desc_g_n_k
,
const
B1GridDesc_G_N_K
&
b1grad_grid_desc_g_n_k
,
index_t
batch_stride_lse
)
index_t
batch_stride_lse
)
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
...
@@ -667,6 +687,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -667,6 +687,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
z_grid_desc_g_m_n_
(
z_grid_desc_g_m_n
),
z_grid_desc_g_m_n_
(
z_grid_desc_g_m_n
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
),
bgrad_grid_desc_g_n_k_
(
bgrad_grid_desc_g_n_k
),
b1grad_grid_desc_g_n_k_
(
b1grad_grid_desc_g_n_k
),
batch_stride_lse_
(
batch_stride_lse
)
batch_stride_lse_
(
batch_stride_lse
)
{
{
}
}
...
@@ -706,6 +728,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -706,6 +728,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
return
g_idx
*
static_cast
<
long_index_t
>
(
batch_stride_lse_
);
return
g_idx
*
static_cast
<
long_index_t
>
(
batch_stride_lse_
);
}
}
__host__
__device__
constexpr
long_index_t
GetBGradBasePtr
(
index_t
g_idx
)
const
{
return
bgrad_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetB1GradBasePtr
(
index_t
g_idx
)
const
{
return
b1grad_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
private:
private:
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
...
@@ -713,6 +745,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -713,6 +745,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
BGridDesc_G_N_K
bgrad_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1grad_grid_desc_g_n_k_
;
index_t
batch_stride_lse_
;
index_t
batch_stride_lse_
;
};
};
...
@@ -817,9 +851,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -817,9 +851,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// tensor descriptors for block/thread-wise copy
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
bgrad_grid_desc_bk0_n_bk1_
;
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3_
;
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
B1GridDesc_BK0_N_BK1
b1grad_grid_desc_bk0_n_bk1_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
...
@@ -861,6 +897,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -861,6 +897,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
std
::
vector
<
index_t
>
c_mz_gemm1nz_strides_
;
std
::
vector
<
index_t
>
c_mz_gemm1nz_strides_
;
// for gridwise gemm check
// for gridwise gemm check
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
index_t
batch_count_
;
index_t
batch_count_
;
...
@@ -893,13 +930,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -893,13 +930,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
float
p_drop
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
:
a_element_op_
{
a_element_op
},
:
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
acc_element_op_
{
acc_element_op
},
acc_element_op_
{
acc_element_op
},
b1_element_op_
{
b1_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
c_element_op_
{
c_element_op
},
p_dropout_
{
p_drop
}
p_dropout_
{
p_drop
},
h_ratio_
{
h_ratio
}
{
{
seed_
=
std
::
get
<
0
>
(
seeds
);
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
...
@@ -960,6 +999,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -960,6 +999,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
problem_desc
.
a_gs_ms_ks_lengths
,
problem_desc
.
a_gs_ms_ks_strides
);
problem_desc
.
a_gs_ms_ks_lengths
,
problem_desc
.
a_gs_ms_ks_strides
);
const
auto
b_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
const
auto
b_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
const
auto
bgrad_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
problem_desc
.
bgrad_gs_ns_ks_lengths
,
problem_desc
.
bgrad_gs_ns_ks_strides
);
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_strides
;
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_strides
;
...
@@ -982,6 +1023,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -982,6 +1023,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const
auto
b1_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeVGridDescriptor_O0_N_O1
(
const
auto
b1_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeVGridDescriptor_O0_N_O1
(
problem_desc
.
b1_gs_gemm1ns_gemm1ks_lengths
,
problem_desc
.
b1_gs_gemm1ns_gemm1ks_lengths
,
problem_desc
.
b1_gs_gemm1ns_gemm1ks_strides
);
problem_desc
.
b1_gs_gemm1ns_gemm1ks_strides
);
const
auto
b1grad_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeVGridDescriptor_O0_N_O1
(
problem_desc
.
b1grad_gs_gemm1ns_gemm1ks_lengths
,
problem_desc
.
b1grad_gs_gemm1ns_gemm1ks_strides
);
const
auto
y_grid_desc_m_o
=
Transform
::
MakeCGridDescriptor_M_N
(
const
auto
y_grid_desc_m_o
=
Transform
::
MakeCGridDescriptor_M_N
(
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
...
@@ -1005,6 +1049,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1005,6 +1049,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
problem_desc
.
b1_gs_gemm1ns_gemm1ks_strides
);
problem_desc
.
b1_gs_gemm1ns_gemm1ks_strides
);
const
auto
c_grid_desc_g_m_n
=
Transform
::
MakeCGridDescriptor_G_M_N
(
const
auto
c_grid_desc_g_m_n
=
Transform
::
MakeCGridDescriptor_G_M_N
(
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
const
auto
bgrad_grid_desc_g_n_k
=
Transform
::
MakeB0GridDescriptor_G_N_K
(
problem_desc
.
bgrad_gs_ns_ks_lengths
,
problem_desc
.
bgrad_gs_ns_ks_strides
);
const
auto
b1grad_grid_desc_g_n_k
=
Transform
::
MakeB1GridDescriptor_G_N_K
(
problem_desc
.
b1grad_gs_gemm1ns_gemm1ks_lengths
,
problem_desc
.
b1grad_gs_gemm1ns_gemm1ks_strides
);
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
;
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
;
const
index_t
BlockStart
=
grid_size_
;
const
index_t
BlockStart
=
grid_size_
;
...
@@ -1027,6 +1076,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1027,6 +1076,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
z_grid_desc_g_m_n
,
z_grid_desc_g_m_n
,
b1_grid_desc_g_n_k
,
b1_grid_desc_g_n_k
,
c_grid_desc_g_m_n
,
c_grid_desc_g_m_n
,
bgrad_grid_desc_g_n_k
,
b1grad_grid_desc_g_n_k
,
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()));
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()));
// C0 mask
// C0 mask
...
@@ -1073,9 +1124,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1073,9 +1124,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_vgrad_grid
,
p_vgrad_grid
,
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
bgrad_grid_desc_bk0_n_bk1
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
z_grid_desc_m_n
,
z_grid_desc_m_n
,
b1_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
b1grad_grid_desc_bk0_n_bk1
,
y_grid_desc_m_o
,
y_grid_desc_m_o
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
lse_grid_desc_m
,
lse_grid_desc_m
,
...
@@ -1119,6 +1172,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1119,6 +1172,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
problem_desc
.
b1_gs_gemm1ns_gemm1ks_strides
[
NumDimG
+
NumDimO
+
NumDimN
-
1
]},
problem_desc
.
b1_gs_gemm1ns_gemm1ks_strides
[
NumDimG
+
NumDimO
+
NumDimN
-
1
]},
{
problem_desc
.
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
-
1
],
{
problem_desc
.
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
-
1
],
problem_desc
.
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
problem_desc
.
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
b_grid_desc_g_n_k
,
c_grid_desc_g_m_n
,
c_grid_desc_g_m_n
,
batch_count
,
batch_count
,
d0_n_length_stride
});
d0_n_length_stride
});
...
@@ -1145,6 +1199,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1145,6 +1199,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
index_t
grid_size_
;
index_t
grid_size_
;
index_t
group_count_
;
index_t
group_count_
;
index_t
h_ratio_
;
std
::
vector
<
GroupKernelArg
>
group_kernel_args_
;
std
::
vector
<
GroupKernelArg
>
group_kernel_args_
;
std
::
vector
<
GroupDeviceArg
>
group_device_args_
;
std
::
vector
<
GroupDeviceArg
>
group_device_args_
;
...
@@ -1224,6 +1279,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1224,6 +1279,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
0
,
0
,
cast_pointer_to_constant_address_space
(
arg
.
p_workspace_
),
cast_pointer_to_constant_address_space
(
arg
.
p_workspace_
),
arg
.
group_count_
,
arg
.
group_count_
,
arg
.
h_ratio_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
acc_element_op_
,
arg
.
acc_element_op_
,
...
@@ -1292,13 +1348,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1292,13 +1348,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const
auto
&
device_arg
=
arg
.
group_device_args_
[
i
];
const
auto
&
device_arg
=
arg
.
group_device_args_
[
i
];
// Check if C permute dimension matches GEMM + GEMM shape
// Check if C permute dimension matches GEMM + GEMM shape
const
index_t
c_g
=
device_arg
.
c_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
c_g
=
device_arg
.
c_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
b_g
=
device_arg
.
b_grid_desc_g_n_k_
.
GetLength
(
I0
);
const
index_t
c_m
=
kernel_arg
.
y_grid_desc_m_o_
.
GetLength
(
I0
);
const
index_t
c_m
=
kernel_arg
.
y_grid_desc_m_o_
.
GetLength
(
I0
);
const
index_t
c_gemm1n
=
kernel_arg
.
y_grid_desc_m_o_
.
GetLength
(
I1
);
const
index_t
c_gemm1n
=
kernel_arg
.
y_grid_desc_m_o_
.
GetLength
(
I1
);
const
index_t
a_m
=
kernel_arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
a_m
=
kernel_arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
b1_gemm1n
=
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
const
index_t
b1_gemm1n
=
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
if
(
!
(
c_g
==
device_arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
))
if
(
!
(
c_g
==
device_arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
&&
c_g
%
b_g
==
0
&&
c_g
/
b_g
==
arg
.
h_ratio_
))
{
{
return
false
;
return
false
;
}
}
...
@@ -1395,6 +1453,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1395,6 +1453,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
float
p_drop
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
{
{
return
Argument
{
p_As
,
return
Argument
{
p_As
,
...
@@ -1419,6 +1478,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1419,6 +1478,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
b1_element_op
,
b1_element_op
,
c_element_op
,
c_element_op
,
p_drop
,
p_drop
,
h_ratio
,
seeds
};
seeds
};
}
}
...
@@ -1449,6 +1509,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1449,6 +1509,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
float
p_drop
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
{
{
return
std
::
make_unique
<
Argument
>
(
p_As
,
return
std
::
make_unique
<
Argument
>
(
p_As
,
...
@@ -1473,6 +1534,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1473,6 +1534,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
b1_element_op
,
b1_element_op
,
c_element_op
,
c_element_op
,
p_drop
,
p_drop
,
h_ratio
,
seeds
);
seeds
);
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
104aeabc
...
@@ -102,6 +102,7 @@ __global__ void
...
@@ -102,6 +102,7 @@ __global__ void
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_light_v2
(
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_light_v2
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
group_kernel_args
,
const
void
CK_CONSTANT_ADDRESS_SPACE
*
group_kernel_args
,
const
index_t
group_count
,
const
index_t
group_count
,
const
index_t
h_ratio
,
const
AElementwiseOperation
a_element_op
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
BElementwiseOperation
b_element_op
,
const
AccElementwiseOperation
acc_element_op
,
const
AccElementwiseOperation
acc_element_op
,
...
@@ -140,19 +141,26 @@ __global__ void
...
@@ -140,19 +141,26 @@ __global__ void
const
index_t
num_blocks_per_batch
=
arg_ptr
[
group_id
].
num_blocks_per_batch_
;
const
index_t
num_blocks_per_batch
=
arg_ptr
[
group_id
].
num_blocks_per_batch_
;
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
(
block_id
-
arg_ptr
[
group_id
].
block_start_
)
/
(
Deterministic
?
1
:
num_blocks_per_batch
));
(
block_id
-
arg_ptr
[
group_id
].
block_start_
)
/
(
Deterministic
?
1
:
num_blocks_per_batch
));
const
index_t
gkv_idx
=
__builtin_amdgcn_readfirstlane
(
g_idx
/
h_ratio
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetABasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetABasePtr
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetBBasePtr
(
g_idx
)));
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetBBasePtr
(
g
kv
_idx
)));
const
long_index_t
z_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
z_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetZBasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetZBasePtr
(
g_idx
)));
const
long_index_t
b1_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
const
long_index_t
b1_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetB1BasePtr
(
g_idx
)));
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetB1BasePtr
(
g
kv
_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetCBasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetCBasePtr
(
g_idx
)));
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetLSEBasePtr
(
g_idx
)));
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetLSEBasePtr
(
g_idx
)));
const
long_index_t
bgrad_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetBGradBasePtr
(
g_idx
)));
const
long_index_t
b1grad_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetB1GradBasePtr
(
g_idx
)));
const
index_t
global_thread_id
=
get_thread_global_1d_id
();
const
index_t
global_thread_id
=
get_thread_global_1d_id
();
ck
::
philox
ph
(
seed
,
global_thread_id
,
offset
);
ck
::
philox
ph
(
seed
,
global_thread_id
,
offset
);
...
@@ -166,7 +174,6 @@ __global__ void
...
@@ -166,7 +174,6 @@ __global__ void
const
long_index_t
d0_batch_offset
=
const
long_index_t
d0_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetD0BasePtr
(
g_idx
)));
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetD0BasePtr
(
g_idx
)));
if
(
arg_ptr
[
group_id
].
p_d0_grid_
!=
nullptr
)
if
(
arg_ptr
[
group_id
].
p_d0_grid_
!=
nullptr
)
tmp_p_d0_grid
=
arg_ptr
[
group_id
].
p_d0_grid_
+
d0_batch_offset
;
tmp_p_d0_grid
=
arg_ptr
[
group_id
].
p_d0_grid_
+
d0_batch_offset
;
if
(
arg_ptr
[
group_id
].
p_d0grad_grid_
)
if
(
arg_ptr
[
group_id
].
p_d0grad_grid_
)
...
@@ -187,9 +194,9 @@ __global__ void
...
@@ -187,9 +194,9 @@ __global__ void
arg_ptr
[
group_id
].
p_d_grid_
+
lse_batch_offset
,
arg_ptr
[
group_id
].
p_d_grid_
+
lse_batch_offset
,
arg_ptr
[
group_id
].
p_ygrad_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_ygrad_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_qgrad_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_qgrad_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_kgrad_grid_
+
b_batch_offset
,
arg_ptr
[
group_id
].
p_kgrad_grid_
+
b
grad
_batch_offset
,
tmp_p_d0grad_grid
,
tmp_p_d0grad_grid
,
arg_ptr
[
group_id
].
p_vgrad_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_vgrad_grid_
+
b1
grad
_batch_offset
,
p_shared
,
p_shared
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
...
@@ -198,9 +205,11 @@ __global__ void
...
@@ -198,9 +205,11 @@ __global__ void
c_element_op
,
c_element_op
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
bgrad_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg_ptr
[
group_id
].
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b1grad_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_m0_o_m1_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_m0_o_m1_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
...
@@ -225,9 +234,9 @@ __global__ void
...
@@ -225,9 +234,9 @@ __global__ void
arg_ptr
[
group_id
].
p_d_grid_
+
lse_batch_offset
,
arg_ptr
[
group_id
].
p_d_grid_
+
lse_batch_offset
,
arg_ptr
[
group_id
].
p_ygrad_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_ygrad_grid_
+
c_batch_offset
,
arg_ptr
[
group_id
].
p_qgrad_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_qgrad_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_kgrad_grid_
+
b_batch_offset
,
arg_ptr
[
group_id
].
p_kgrad_grid_
+
b
grad
_batch_offset
,
tmp_p_d0grad_grid
,
tmp_p_d0grad_grid
,
arg_ptr
[
group_id
].
p_vgrad_grid_
+
b1_batch_offset
,
arg_ptr
[
group_id
].
p_vgrad_grid_
+
b1
grad
_batch_offset
,
p_shared
,
p_shared
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
...
@@ -236,9 +245,11 @@ __global__ void
...
@@ -236,9 +245,11 @@ __global__ void
c_element_op
,
c_element_op
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
bgrad_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg_ptr
[
group_id
].
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b1grad_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_m0_o_m1_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_m0_o_m1_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
...
@@ -253,6 +264,7 @@ __global__ void
...
@@ -253,6 +264,7 @@ __global__ void
#else
#else
ignore
=
group_kernel_args
;
ignore
=
group_kernel_args
;
ignore
=
group_count
;
ignore
=
group_count
;
ignore
=
h_ratio
;
ignore
=
a_element_op
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
b_element_op
;
ignore
=
acc_element_op
;
ignore
=
acc_element_op
;
...
@@ -373,6 +385,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -373,6 +385,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
std
::
vector
<
index_t
>
lse_gs_ms_lengths
;
std
::
vector
<
index_t
>
lse_gs_ms_lengths
;
std
::
vector
<
index_t
>
lse_gs_ms_strides
;
std
::
vector
<
index_t
>
lse_gs_ms_strides
;
std
::
vector
<
index_t
>
bgrad_gs_ns_ks_lengths
;
std
::
vector
<
index_t
>
bgrad_gs_ns_ks_strides
;
std
::
vector
<
index_t
>
b1grad_gs_gemm1ns_gemm1ks_lengths
;
std
::
vector
<
index_t
>
b1grad_gs_gemm1ns_gemm1ks_strides
;
std
::
vector
<
index_t
>
acc0_bias_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
acc0_bias_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
acc0_bias_gs_ms_ns_strides
;
std
::
vector
<
index_t
>
acc0_bias_gs_ms_ns_strides
;
...
@@ -639,7 +657,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -639,7 +657,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeC0GridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
return
Transform
::
MakeC0GridDescriptor_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_strides
);
}
}
...
@@ -648,7 +665,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -648,7 +665,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
MakeD0GridDescriptor_G_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
MakeD0GridDescriptor_G_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
const
std
::
vector
<
ck
::
index_t
>&
acc0_bias_gs_ms_ns_strides
)
{
{
return
Transform
::
MakeC0GridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
return
Transform
::
MakeC0GridDescriptor_G_M_N
(
acc0_bias_gs_ms_ns_lengths
,
acc0_bias_gs_ms_ns_strides
);
acc0_bias_gs_ms_ns_strides
);
}
}
...
@@ -723,6 +739,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -723,6 +739,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
ZGridDesc_G_M_N
&
z_grid_desc_g_m_n
,
const
ZGridDesc_G_M_N
&
z_grid_desc_g_m_n
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
,
const
BGridDesc_G_N_K
&
bgrad_grid_desc_g_n_k
,
const
B1GridDesc_G_N_K
&
b1grad_grid_desc_g_n_k
,
index_t
BatchStrideLSE
)
index_t
BatchStrideLSE
)
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
...
@@ -730,6 +748,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -730,6 +748,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
z_grid_desc_g_m_n_
(
z_grid_desc_g_m_n
),
z_grid_desc_g_m_n_
(
z_grid_desc_g_m_n
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
),
bgrad_grid_desc_g_n_k_
(
bgrad_grid_desc_g_n_k
),
b1grad_grid_desc_g_n_k_
(
b1grad_grid_desc_g_n_k
),
BatchStrideLSE_
(
BatchStrideLSE
)
BatchStrideLSE_
(
BatchStrideLSE
)
{
{
}
}
...
@@ -769,6 +789,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -769,6 +789,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideLSE_
);
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideLSE_
);
}
}
__host__
__device__
constexpr
long_index_t
GetBGradBasePtr
(
index_t
g_idx
)
const
{
return
bgrad_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetB1GradBasePtr
(
index_t
g_idx
)
const
{
return
b1grad_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
private:
private:
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
...
@@ -776,6 +806,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -776,6 +806,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
BGridDesc_G_N_K
bgrad_grid_desc_g_n_k_
;
B1GridDesc_G_N_K
b1grad_grid_desc_g_n_k_
;
index_t
BatchStrideLSE_
;
index_t
BatchStrideLSE_
;
};
};
...
@@ -888,9 +920,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -888,9 +920,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// tensor descriptors for block/thread-wise copy
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
bgrad_grid_desc_bk0_n_bk1_
;
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3_
;
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
B1GridDesc_BK0_N_BK1
b1grad_grid_desc_bk0_n_bk1_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
...
@@ -932,6 +966,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -932,6 +966,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
std
::
vector
<
index_t
>
c_mz_gemm1nz_strides_
;
std
::
vector
<
index_t
>
c_mz_gemm1nz_strides_
;
// for gridwise gemm check
// for gridwise gemm check
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
index_t
batch_count_
;
index_t
batch_count_
;
...
@@ -964,13 +999,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -964,13 +999,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
float
p_drop
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
:
a_element_op_
{
a_element_op
},
:
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
acc_element_op_
{
acc_element_op
},
acc_element_op_
{
acc_element_op
},
b1_element_op_
{
b1_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
c_element_op_
{
c_element_op
},
p_dropout_
{
p_drop
}
p_dropout_
{
p_drop
},
h_ratio_
{
h_ratio
}
{
{
seed_
=
std
::
get
<
0
>
(
seeds
);
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
...
@@ -1031,6 +1068,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1031,6 +1068,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
problem_desc
.
a_gs_ms_ks_lengths
,
problem_desc
.
a_gs_ms_ks_strides
);
problem_desc
.
a_gs_ms_ks_lengths
,
problem_desc
.
a_gs_ms_ks_strides
);
const
auto
b_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
const
auto
b_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
const
auto
bgrad_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
problem_desc
.
bgrad_gs_ns_ks_lengths
,
problem_desc
.
bgrad_gs_ns_ks_strides
);
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_strides
;
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_strides
;
...
@@ -1053,6 +1092,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1053,6 +1092,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
auto
b1_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeVGridDescriptor_O0_N_O1
(
const
auto
b1_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeVGridDescriptor_O0_N_O1
(
problem_desc
.
b1_gs_gemm1ns_gemm1ks_lengths
,
problem_desc
.
b1_gs_gemm1ns_gemm1ks_lengths
,
problem_desc
.
b1_gs_gemm1ns_gemm1ks_strides
);
problem_desc
.
b1_gs_gemm1ns_gemm1ks_strides
);
const
auto
b1grad_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeVGridDescriptor_O0_N_O1
(
problem_desc
.
b1grad_gs_gemm1ns_gemm1ks_lengths
,
problem_desc
.
b1grad_gs_gemm1ns_gemm1ks_strides
);
const
auto
y_grid_desc_m_o
=
Transform
::
MakeCGridDescriptor_M_N
(
const
auto
y_grid_desc_m_o
=
Transform
::
MakeCGridDescriptor_M_N
(
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
...
@@ -1076,6 +1118,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1076,6 +1118,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
problem_desc
.
b1_gs_gemm1ns_gemm1ks_strides
);
problem_desc
.
b1_gs_gemm1ns_gemm1ks_strides
);
const
auto
c_grid_desc_g_m_n
=
Transform
::
MakeCGridDescriptor_G_M_N
(
const
auto
c_grid_desc_g_m_n
=
Transform
::
MakeCGridDescriptor_G_M_N
(
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
const
auto
bgrad_grid_desc_g_n_k
=
Transform
::
MakeB0GridDescriptor_G_N_K
(
problem_desc
.
bgrad_gs_ns_ks_lengths
,
problem_desc
.
bgrad_gs_ns_ks_strides
);
const
auto
b1grad_grid_desc_g_n_k
=
Transform
::
MakeB1GridDescriptor_G_N_K
(
problem_desc
.
b1grad_gs_gemm1ns_gemm1ks_lengths
,
problem_desc
.
b1grad_gs_gemm1ns_gemm1ks_strides
);
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
;
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
;
const
index_t
BlockStart
=
grid_size_
;
const
index_t
BlockStart
=
grid_size_
;
...
@@ -1098,6 +1145,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1098,6 +1145,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
z_grid_desc_g_m_n
,
z_grid_desc_g_m_n
,
b1_grid_desc_g_n_k
,
b1_grid_desc_g_n_k
,
c_grid_desc_g_m_n
,
c_grid_desc_g_m_n
,
bgrad_grid_desc_g_n_k
,
b1grad_grid_desc_g_n_k
,
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()));
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()));
// C0 mask
// C0 mask
...
@@ -1144,9 +1193,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1144,9 +1193,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_vgrad_grid
,
p_vgrad_grid
,
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
bgrad_grid_desc_bk0_n_bk1
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
z_grid_desc_m_n
,
z_grid_desc_m_n
,
b1_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
b1grad_grid_desc_bk0_n_bk1
,
y_grid_desc_m_o
,
y_grid_desc_m_o
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
lse_grid_desc_m
,
lse_grid_desc_m
,
...
@@ -1190,6 +1241,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1190,6 +1241,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
problem_desc
.
b1_gs_gemm1ns_gemm1ks_strides
[
NumDimG
+
NumDimO
+
NumDimN
-
1
]},
problem_desc
.
b1_gs_gemm1ns_gemm1ks_strides
[
NumDimG
+
NumDimO
+
NumDimN
-
1
]},
{
problem_desc
.
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
-
1
],
{
problem_desc
.
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
-
1
],
problem_desc
.
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
problem_desc
.
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
b_grid_desc_g_n_k
,
c_grid_desc_g_m_n
,
c_grid_desc_g_m_n
,
batch_count
,
batch_count
,
d0_n_length_stride
});
d0_n_length_stride
});
...
@@ -1216,6 +1268,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1216,6 +1268,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
index_t
grid_size_
;
index_t
grid_size_
;
index_t
group_count_
;
index_t
group_count_
;
index_t
h_ratio_
;
std
::
vector
<
GroupKernelArg
>
group_kernel_args_
;
std
::
vector
<
GroupKernelArg
>
group_kernel_args_
;
std
::
vector
<
GroupDeviceArg
>
group_device_args_
;
std
::
vector
<
GroupDeviceArg
>
group_device_args_
;
...
@@ -1294,6 +1347,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1294,6 +1347,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
0
,
0
,
cast_pointer_to_constant_address_space
(
arg
.
p_workspace_
),
cast_pointer_to_constant_address_space
(
arg
.
p_workspace_
),
arg
.
group_count_
,
arg
.
group_count_
,
arg
.
h_ratio_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
acc_element_op_
,
arg
.
acc_element_op_
,
...
@@ -1362,13 +1416,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1362,13 +1416,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
auto
&
device_arg
=
arg
.
group_device_args_
[
i
];
const
auto
&
device_arg
=
arg
.
group_device_args_
[
i
];
// Check if C permute dimension matches GEMM + GEMM shape
// Check if C permute dimension matches GEMM + GEMM shape
const
index_t
c_g
=
device_arg
.
c_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
c_g
=
device_arg
.
c_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
b_g
=
device_arg
.
b_grid_desc_g_n_k_
.
GetLength
(
I0
);
const
index_t
c_m
=
kernel_arg
.
y_grid_desc_m_o_
.
GetLength
(
I0
);
const
index_t
c_m
=
kernel_arg
.
y_grid_desc_m_o_
.
GetLength
(
I0
);
const
index_t
c_gemm1n
=
kernel_arg
.
y_grid_desc_m_o_
.
GetLength
(
I1
);
const
index_t
c_gemm1n
=
kernel_arg
.
y_grid_desc_m_o_
.
GetLength
(
I1
);
const
index_t
a_m
=
kernel_arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
a_m
=
kernel_arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
b1_gemm1n
=
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
const
index_t
b1_gemm1n
=
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
if
(
!
(
c_g
==
device_arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
))
if
(
!
(
c_g
==
device_arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
&&
c_g
%
b_g
==
0
&&
c_g
/
b_g
==
arg
.
h_ratio_
))
{
{
return
false
;
return
false
;
}
}
...
@@ -1471,6 +1527,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1471,6 +1527,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
float
p_drop
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
{
{
return
Argument
{
p_As
,
return
Argument
{
p_As
,
...
@@ -1495,6 +1552,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1495,6 +1552,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
b1_element_op
,
b1_element_op
,
c_element_op
,
c_element_op
,
p_drop
,
p_drop
,
h_ratio
,
seeds
};
seeds
};
}
}
...
@@ -1525,6 +1583,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1525,6 +1583,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
float
p_drop
,
float
p_drop
,
index_t
h_ratio
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
// override
{
{
return
std
::
make_unique
<
Argument
>
(
p_As
,
return
std
::
make_unique
<
Argument
>
(
p_As
,
...
@@ -1549,6 +1608,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1549,6 +1608,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
b1_element_op
,
b1_element_op
,
c_element_op
,
c_element_op
,
p_drop
,
p_drop
,
h_ratio
,
seeds
);
seeds
);
}
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
View file @
104aeabc
...
@@ -1440,10 +1440,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1440,10 +1440,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const
CElementwiseOperation
&
c_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
QGridDesc_K0_M_K1
&
q_grid_desc_k0_m_k1
,
const
QGridDesc_K0_M_K1
&
q_grid_desc_k0_m_k1
,
const
KGridDesc_K0_N_K1
&
k_grid_desc_k0_n_k1
,
const
KGridDesc_K0_N_K1
&
k_grid_desc_k0_n_k1
,
const
KGridDesc_K0_N_K1
&
kgrad_grid_desc_k0_n_k1
,
const
D0GridDescriptor_M0_N0_M1_M2_N1_M3
&
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
const
D0GridDescriptor_M0_N0_M1_M2_N1_M3
&
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
&
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
VGridDesc_O0_N_O1
&
v_grid_desc_o0_n_o1
,
const
VGridDesc_O0_N_O1
&
v_grid_desc_o0_n_o1
,
const
VGridDesc_O0_N_O1
&
vgrad_grid_desc_o0_n_o1
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
const
YGradGridDesc_O0_M_O1
&
ygrad_grid_desc_o0_m_o1
,
const
YGradGridDesc_O0_M_O1
&
ygrad_grid_desc_o0_m_o1
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
Block2CTileMap
&
block_2_ctile_map
,
...
@@ -1474,11 +1476,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1474,11 +1476,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const
auto
ygrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
ygrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ygrad_grid
,
ygrad_grid_desc_o0_m_o1
.
GetElementSpaceSize
());
p_ygrad_grid
,
ygrad_grid_desc_o0_m_o1
.
GetElementSpaceSize
());
auto
vgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
vgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_vgrad_grid
,
v_grid_desc_o0_n_o1
.
GetElementSpaceSize
());
p_vgrad_grid
,
v
grad
_grid_desc_o0_n_o1
.
GetElementSpaceSize
());
auto
qgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
qgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_qgrad_grid
,
q_grid_desc_k0_m_k1
.
GetElementSpaceSize
());
p_qgrad_grid
,
q_grid_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
kgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
kgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_kgrad_grid
,
k_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
p_kgrad_grid
,
k
grad
_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
// divide block work by [N, K]
// divide block work by [N, K]
const
auto
block_work_idx
=
const
auto
block_work_idx
=
...
@@ -1628,7 +1630,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1628,7 +1630,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// dV: transform input and output tensor descriptors
// dV: transform input and output tensor descriptors
auto
vgrad_grid_desc_nblock_nperblock_oblock_operblock
=
auto
vgrad_grid_desc_nblock_nperblock_oblock_operblock
=
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
v_grid_desc_o0_n_o1
);
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
v
grad
_grid_desc_o0_n_o1
);
// dK: A matrix blockwise copy
// dK: A matrix blockwise copy
auto
kgrad_gemm_tile_sgrad_blockwise_copy
=
auto
kgrad_gemm_tile_sgrad_blockwise_copy
=
...
@@ -1657,7 +1659,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
...
@@ -1657,7 +1659,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// dK: transform input and output tensor descriptors
// dK: transform input and output tensor descriptors
auto
kgrad_grid_desc_nblock_nperblock_oblock_operblock
=
auto
kgrad_grid_desc_nblock_nperblock_oblock_operblock
=
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
k_grid_desc_k0_n_k1
);
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
k
grad
_grid_desc_k0_n_k1
);
//
//
// set up dQ Gemm (type 3 crr)
// set up dQ Gemm (type 3 crr)
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
View file @
104aeabc
...
@@ -1531,10 +1531,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1531,10 +1531,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
CElementwiseOperation
&
c_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
QGridDesc_K0_M_K1
&
q_grid_desc_k0_m_k1
,
const
QGridDesc_K0_M_K1
&
q_grid_desc_k0_m_k1
,
const
KGridDesc_K0_N_K1
&
k_grid_desc_k0_n_k1
,
const
KGridDesc_K0_N_K1
&
k_grid_desc_k0_n_k1
,
const
KGridDesc_K0_N_K1
&
kgrad_grid_desc_k0_n_k1
,
const
D0GridDescriptor_M0_N0_M1_M2_N1_M3
&
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
const
D0GridDescriptor_M0_N0_M1_M2_N1_M3
&
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
&
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
VGridDesc_O0_N_O1
&
v_grid_desc_o0_n_o1
,
const
VGridDesc_O0_N_O1
&
v_grid_desc_o0_n_o1
,
const
VGridDesc_O0_N_O1
&
vgrad_grid_desc_o0_n_o1
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
const
YGradGridDesc_M0_O_M1
&
ygrad_grid_desc_m0_o_m1
,
const
YGradGridDesc_M0_O_M1
&
ygrad_grid_desc_m0_o_m1
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
Block2CTileMap
&
block_2_ctile_map
,
...
@@ -1565,11 +1567,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1565,11 +1567,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const
auto
ygrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
ygrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ygrad_grid
,
ygrad_grid_desc_m0_o_m1
.
GetElementSpaceSize
());
p_ygrad_grid
,
ygrad_grid_desc_m0_o_m1
.
GetElementSpaceSize
());
auto
vgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
vgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_vgrad_grid
,
v_grid_desc_o0_n_o1
.
GetElementSpaceSize
());
p_vgrad_grid
,
v
grad
_grid_desc_o0_n_o1
.
GetElementSpaceSize
());
auto
qgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
qgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_qgrad_grid
,
q_grid_desc_k0_m_k1
.
GetElementSpaceSize
());
p_qgrad_grid
,
q_grid_desc_k0_m_k1
.
GetElementSpaceSize
());
auto
kgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
kgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_kgrad_grid
,
k_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
p_kgrad_grid
,
k
grad
_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
// divide block work by [N, K]
// divide block work by [N, K]
const
auto
block_work_idx
=
const
auto
block_work_idx
=
...
@@ -1742,7 +1744,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1742,7 +1744,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// dV: transform input and output tensor descriptors
// dV: transform input and output tensor descriptors
auto
vgrad_grid_desc_nblock_nperblock_oblock_operblock
=
auto
vgrad_grid_desc_nblock_nperblock_oblock_operblock
=
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
v_grid_desc_o0_n_o1
);
MakeVGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
v
grad
_grid_desc_o0_n_o1
);
// dK: transform input and output tensor descriptors
// dK: transform input and output tensor descriptors
const
auto
q_grid_desc_m0_k_m1
=
const
auto
q_grid_desc_m0_k_m1
=
...
@@ -1775,7 +1777,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
...
@@ -1775,7 +1777,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// dK: transform input and output tensor descriptors
// dK: transform input and output tensor descriptors
auto
kgrad_grid_desc_nblock_nperblock_oblock_operblock
=
auto
kgrad_grid_desc_nblock_nperblock_oblock_operblock
=
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
k_grid_desc_k0_n_k1
);
MakeKGradGridDesc_NBlock_NPerBlock_OBlock_OPerBlock
(
k
grad
_grid_desc_k0_n_k1
);
//
//
// set up dQ Gemm (type 3 crr)
// set up dQ Gemm (type 3 crr)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment