Commit ad77ce8e authored by letaoqin's avatar letaoqin
Browse files

fix for passthrough element op

parent f820c621
......@@ -544,10 +544,14 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id());
if constexpr(!is_same_v<CDEElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough>)
{
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i], ds_grid_desc_m0_m10_m11_n0_n10_n11[i].GetElementSpaceSize());
p_ds_grid[i],
ds_grid_desc_m0_m10_m11_n0_n10_n11[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
......@@ -583,26 +587,20 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
1,
false>(
ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
make_multi_index(
im0,
make_multi_index(im0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
in0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])); // register number
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]));
},
Number<NumDTensor>{});
static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I0], 1>{}([&](auto m10) {
static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I1], 1>{}([&](auto m11) {
static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I2], 1>{}([&](auto n10) {
ignore = m10;
ignore = m11;
ignore = n10;
ignore = ds_thread_buf;
ignore = ds_threadwise_copy;
ignore = ds_grid_buf;
static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I2], 1>{}(
[&](auto n10) {
// load d matrix data
static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_threadwise_copy(i).Run(ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
ds_grid_buf[i],
......@@ -610,7 +608,7 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
make_tuple(I0, I0, I0, I0, I0, I0),
ds_thread_buf(i));
});
// cal element op
static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I3], 1>{}(
[&](auto i) {
// get reference to src data
......@@ -627,7 +625,9 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
make_tuple(0, m10, m11, 0, n10, i));
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto) -> auto& { return c_thread_buf(Number<c_offset>{}); },
[&](auto) -> auto& {
return c_thread_buf(Number<c_offset>{});
},
Number<2>{});
unpack2(cde_element_op, dst_data_refs, src_data_refs);
......@@ -653,6 +653,7 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
0, 1, -c_m10_m11_n10_n11_thread_tensor_lengths[I1], 0, 0, 0));
});
});
}
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment