Commit f820c621 authored by letaoqin
Browse files

fix spill registers

parent 952f02ff
......@@ -62,7 +62,7 @@ bool run_grouped_conv_fwd_dl(bool do_verification,
case 1:
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 3});
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 3});
bias.GenerateTensorValue(GeneratorTensor_2<DDataType>{-1, 1});
bias.GenerateTensorValue(GeneratorTensor_2<DDataType>{-2, 3});
break;
case 2:
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
......
......@@ -11,7 +11,7 @@ bool run_convnd_fwd_dl_example(int argc, char* argv[])
bool time_kernel = false;
ck::utils::conv::ConvParam conv_param{
2, 1, 128, 256, 192, {3, 3}, {14, 14}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
2, 1, 128, 128, 192, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
if(argc == 1)
{
......
......@@ -557,7 +557,7 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
return StaticBuffer<AddressSpaceEnum::Vgpr,
DDataType,
c_thread_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize(),
c_m10_m11_n10_n11_thread_tensor_lengths[I3],
true>{};
},
Number<NumDTensor>{});
......@@ -572,10 +572,10 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
decltype(ds_grid_desc_m0_m10_m11_n0_n10_n11[i]),
decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
Sequence<I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
I1,
I1,
I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
......@@ -593,30 +593,66 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
},
Number<NumDTensor>{});
static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_threadwise_copy(i).Run(ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
ds_grid_buf[i],
c_thread_desc_m0_m10_m11_n0_n10_n11,
make_tuple(I0, I0, I0, I0, I0, I0),
ds_thread_buf(i));
});
static_for<0, c_thread_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize(), 1>{}(
[&](auto i) {
// get reference to src data
const auto src_data_refs = generate_tie(
// return type should be lvalue
[&](auto iSrc) -> const auto& { return ds_thread_buf[iSrc][i]; },
Number<NumDTensor>{});
// get reference to dst data
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto) -> auto& { return c_thread_buf(i); },
Number<2>{});
unpack2(cde_element_op, dst_data_refs, src_data_refs);
static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I0], 1>{}([&](auto m10) {
static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I1], 1>{}([&](auto m11) {
static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I2], 1>{}([&](auto n10) {
ignore = m10;
ignore = m11;
ignore = n10;
ignore = ds_thread_buf;
ignore = ds_threadwise_copy;
ignore = ds_grid_buf;
static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_threadwise_copy(i).Run(ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
ds_grid_buf[i],
c_thread_desc_m0_m10_m11_n0_n10_n11,
make_tuple(I0, I0, I0, I0, I0, I0),
ds_thread_buf(i));
});
static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I3], 1>{}(
[&](auto i) {
// get reference to src data
const auto src_data_refs = generate_tie(
// return type should be lvalue
[&](auto iSrc) -> const auto& {
return ds_thread_buf[iSrc][i];
},
Number<NumDTensor>{});
// get reference to dst data
constexpr index_t c_offset =
c_thread_desc_m0_m10_m11_n0_n10_n11.CalculateOffset(
make_tuple(0, m10, m11, 0, n10, i));
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto) -> auto& { return c_thread_buf(Number<c_offset>{}); },
Number<2>{});
unpack2(cde_element_op, dst_data_refs, src_data_refs);
});
static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_threadwise_copy(i).MoveSrcSliceWindow(
ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
make_multi_index(0, 0, 0, 0, 1, 0));
});
});
static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_threadwise_copy(i).MoveSrcSliceWindow(
ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
make_multi_index(
0, 0, 1, 0, -c_m10_m11_n10_n11_thread_tensor_lengths[I2], 0));
});
});
static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_threadwise_copy(i).MoveSrcSliceWindow(
ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
make_multi_index(
0, 1, -c_m10_m11_n10_n11_thread_tensor_lengths[I1], 0, 0, 0));
});
});
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment