Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e327363f
Commit
e327363f
authored
Feb 06, 2023
by
fsx950223
Browse files
fix bugs
parent
0fe4fb38
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
62 additions
and
82 deletions
+62
-82
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_fp16.cpp
...oftmax_gemm/grouped_multihead_attention_backward_fp16.cpp
+35
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp
..._grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp
+27
-78
No files found.
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_fp16.cpp
View file @
e327363f
...
...
@@ -306,17 +306,18 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_softmax_invoker
.
Run
(
ref_softmax_argument
);
// P_dropped
auto
ref_dropout
=
ReferenceDropoutInstance
{};
auto
ref_dropout_invoker
=
ref_dropout
.
MakeInvoker
();
auto
ref_dropout_argment
=
ref_dropout
.
MakeArgument
(
z_g_m_n
,
p_g_m_n
,
p_drop_g_m_n
,
p_dropout_in_16bits
,
rp_dropout
);
ref_dropout_invoker
.
Run
(
ref_dropout_argment
);
// Y = P * V
// Y = P
_dropout
* V
auto
ref_gemm1
=
ReferenceGemm1Instance
{};
auto
ref_gemm1_invoker
=
ref_gemm1
.
MakeInvoker
();
auto
ref_gemm1_argument
=
ref_gemm1
.
MakeArgument
(
p_g_m_n
,
v_g_n_o
,
y_g_m_o
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
p_
drop_
g_m_n
,
v_g_n_o
,
y_g_m_o
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
ref_gemm1_invoker
.
Run
(
ref_gemm1_argument
);
}
...
...
@@ -425,8 +426,8 @@ int run(int argc, char* argv[])
{
int
M
=
128
*
(
rand
()
%
4
+
1
);
int
N
=
128
*
(
rand
()
%
4
+
1
);
int
K
=
64
;
int
O
=
64
;
int
K
=
128
;
int
O
=
128
;
int
G0
=
rand
()
%
3
+
1
;
int
G1
=
rand
()
%
2
+
1
;
std
::
vector
<
ck
::
index_t
>
q_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
...
...
@@ -720,6 +721,36 @@ int run(int argc, char* argv[])
kgrad_tensors_device
[
i
]
->
SetZero
();
vgrad_tensors_device
[
i
]
->
SetZero
();
}
// p_z = std::vector<void*>(p_z.size(), nullptr);
// argument =
// gemm.MakeArgument(p_q,
// p_k,
// p_z,
// p_v,
// p_y,
// p_lse,
// p_ygrad,
// p_qgrad,
// p_kgrad,
// p_vgrad,
// {}, // std::array<void*, 1> p_acc0_biases;
// {}, // std::array<void*, 1> p_acc1_biases;
// problem_descs,
// QKVElementOp{},
// QKVElementOp{},
// Scale{alpha},
// QKVElementOp{},
// YElementOp{},
// p_drop,
// std::tuple<unsigned long long, unsigned long long>(seed, offset));
// DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
// gemm.SetWorkSpacePointer(&argument, problem_desc_workspace.GetDeviceBuffer());
// if(!gemm.IsSupportedArgument(argument))
// {
// std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
// return 0;
// }
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp
View file @
e327363f
...
...
@@ -98,7 +98,7 @@ __global__ void
unsigned
short
*
z_matrix_ptr
=
(
arg_ptr
[
group_id
].
p_z_grid_
==
nullptr
?
nullptr
:
arg_ptr
[
group_id
].
p_z_grid_
+
z_batch_offset
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
...
...
@@ -379,56 +379,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
// V in Gemm B position
static
auto
MakeVGridDescriptor_O0_N_O1
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths_vec
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides_vec
)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
// transformation overhead
// TODO: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
// extract subsequence and shuffle them.
const
index_t
num_dims
=
NumDimG
+
NumDimN
+
NumDimO
;
// 0, 1, .. NumDimG - 1
std
::
vector
<
index_t
>
gs_ids
(
NumDimG
);
std
::
iota
(
gs_ids
.
begin
(),
gs_ids
.
end
(),
0
);
// NumDimG, NumDimG + 1, ... NumDimG + NumDimO - 1
std
::
vector
<
index_t
>
os_ids
(
NumDimO
);
std
::
iota
(
os_ids
.
begin
(),
os_ids
.
end
(),
NumDimG
);
// NumDimG + NumDimO, NumDimG + NumDimO + 1, ... NumDimG + NumDimO + NumDimN - 1
std
::
vector
<
index_t
>
ns_ids
(
NumDimN
);
std
::
iota
(
ns_ids
.
begin
(),
ns_ids
.
end
(),
NumDimG
+
NumDimO
);
std
::
vector
<
index_t
>
ids_old2new
;
ids_old2new
.
insert
(
ids_old2new
.
end
(),
gs_ids
.
begin
(),
gs_ids
.
end
());
ids_old2new
.
insert
(
ids_old2new
.
end
(),
ns_ids
.
begin
(),
ns_ids
.
end
());
ids_old2new
.
insert
(
ids_old2new
.
end
(),
os_ids
.
begin
(),
os_ids
.
end
());
std
::
vector
<
index_t
>
v_gs_ns_os_lengths_vec
(
num_dims
),
v_gs_ns_os_strides_vec
(
num_dims
);
for
(
int
i
=
0
;
i
<
num_dims
;
i
++
)
{
index_t
id_new
=
ids_old2new
[
i
];
v_gs_ns_os_lengths_vec
[
i
]
=
v_gs_os_ns_lengths_vec
[
id_new
];
v_gs_ns_os_strides_vec
[
i
]
=
v_gs_os_ns_strides_vec
[
id_new
];
}
const
auto
v_grid_desc_nraw_oraw
=
MakeGridDescriptorPair
<
NumDimG
,
NumDimN
,
NumDimO
,
TensorSpecialization
::
Default
>
(
v_gs_ns_os_lengths_vec
,
v_gs_ns_os_strides_vec
)
.
second
;
const
auto
v_grid_desc_n_o
=
PadTensorDescriptor
(
v_grid_desc_nraw_oraw
,
make_tuple
(
NPerBlock
,
Gemm1NPerBlock
),
Sequence
<
padder
.
PadN
,
padder
.
PadO
>
{});
// N_O to O0_N_O1; to refactor
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
v_grid_desc_n_o
,
Number
<
V_O1
>
{});
}
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
//
...
...
@@ -488,7 +438,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
({},
{}));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeB
1
GridDescriptor_BK0_N_BK1
({},
{}));
using
YGridDesc_M_O
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
LSEGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
using
AGridDesc_G_M_K
=
decltype
(
Transform
::
MakeAGridDescriptor_G_M_K
({},
{}));
...
...
@@ -769,7 +719,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
const
auto
z_grid_desc_m_n
=
DeviceOp
::
MakeZGridDescriptor_M_N
(
problem_desc
.
z_gs_ms_ns_lengths
,
problem_desc
.
z_gs_ms_ns_strides
);
const
auto
b1_grid_desc_bk0_n_bk1
=
DeviceOp
::
Make
V
GridDescriptor_
O
0_N_
O
1
(
const
auto
b1_grid_desc_bk0_n_bk1
=
DeviceOp
::
Make
B1
GridDescriptor_
BK
0_N_
BK
1
(
problem_desc
.
b1_gs_gemm1ns_gemm1ks_lengths
,
problem_desc
.
b1_gs_gemm1ns_gemm1ks_strides
);
const
auto
y_grid_desc_m_o
=
Transform
::
MakeCGridDescriptor_M_N
(
...
...
@@ -927,16 +877,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
throw
std
::
runtime_error
(
"wrong! unsupported argument"
);
}
//
bool all_has_main_k_block_loop = true;
//
bool some_has_main_k_block_loop = false;
//
for(std::size_t i = 0; i < arg.group_count_; i++)
//
{
//
const auto K =
arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
//
arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
//
const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K);
//
all_has_main_k_block_loop &= y;
//
some_has_main_k_block_loop |= y;
//
}
bool
all_has_main_k_block_loop
=
true
;
bool
some_has_main_k_block_loop
=
false
;
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
group_count_
;
i
++
)
{
const
auto
K
=
arg
.
group_kernel_args_
[
i
].
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
group_kernel_args_
[
i
].
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
const
bool
y
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
);
all_has_main_k_block_loop
&=
y
;
some_has_main_k_block_loop
|=
y
;
}
hipGetErrorString
(
hipMemcpy
(
arg
.
p_workspace_
,
arg
.
group_kernel_args_
.
data
(),
...
...
@@ -976,19 +926,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to be concerned with Gemm0's loop
//
if(all_has_main_k_block_loop)
//
{
//
ave_time = launch_kernel(integral_constant<bool, true>{});
//
}
//
else if(!some_has_main_k_block_loop)
//
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
//
}
//
else
//
{
//
throw std::runtime_error("wrong! all gemm problems have to simultaneously meet "
//
"has_main_k_block_loop or no_main_k_block_loop");
//
}
if
(
all_has_main_k_block_loop
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
if
(
!
some_has_main_k_block_loop
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
else
{
throw
std
::
runtime_error
(
"wrong! all gemm problems have to simultaneously meet "
"has_main_k_block_loop or no_main_k_block_loop"
);
}
return
ave_time
;
}
...
...
@@ -1023,8 +973,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
const
index_t
c_m
=
kernel_arg
.
y_grid_desc_m_o_
.
GetLength
(
I0
);
const
index_t
c_gemm1n
=
kernel_arg
.
y_grid_desc_m_o_
.
GetLength
(
I1
);
const
index_t
a_m
=
kernel_arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
b1_gemm1n
=
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
const
index_t
b1_gemm1n
=
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
);
if
(
!
(
c_g
==
device_arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
))
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment