Commit 6a2521ea authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed splitk crush

parent af2c0166
......@@ -147,7 +147,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
// weight permute
......
......@@ -389,7 +389,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
else
{
// Weight Tile Permute
#ifndef WEIGHT_PERMUTE
// not pad N or K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
......@@ -398,23 +397,27 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
#else
// Weight Tile Permute
constexpr index_t BK01 = KPerBlock / BK1Value;
const index_t BK00 = BK0 / BK01;
const index_t BK0_ = StrideB / BK1Value;
const index_t BK00 = BK0_ / BK01;
const auto b_grid_desc_bk00_n_bk01_bk1 =
const auto b_grid_desc_bk00_n_bk01_bk1_permute =
make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_bk00_n_bk01_bk1,
const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
b_grid_desc_bk00_n_bk01_bk1_permute,
make_tuple(make_merge_transform(make_tuple(BK00, BK01)),
make_pass_through_transform(make_tuple(N)),
make_pass_through_transform(BK1Value)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif
return b_grid_desc_bk0_n_bk1;
return b_grid_desc_bk0_n_bk1_permute;
#endif
}
}
......
......@@ -14,9 +14,7 @@ namespace ck {
inline __device__ int amd_assembly_and_or_b32(int a, int b, int d)
{
int c;
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(c)
: "v"(a), "v"(b), "v"(d));
asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(c) : "v"(a), "v"(b), "v"(d));
return c;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment