Commit 6a2521ea authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed splitk crush

parent af2c0166
...@@ -147,7 +147,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -147,7 +147,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
// weight permute // weight permute
......
...@@ -389,7 +389,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -389,7 +389,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
} }
else else
{ {
// Weight Tile Permute
#ifndef WEIGHT_PERMUTE #ifndef WEIGHT_PERMUTE
// not pad N or K // not pad N or K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
...@@ -398,23 +397,27 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -398,23 +397,27 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_pass_through_transform(N)), make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
#else #else
// Weight Tile Permute
constexpr index_t BK01 = KPerBlock / BK1Value; constexpr index_t BK01 = KPerBlock / BK1Value;
const index_t BK00 = BK0 / BK01; const index_t BK0_ = StrideB / BK1Value;
const index_t BK00 = BK0_ / BK01;
const auto b_grid_desc_bk00_n_bk01_bk1 = const auto b_grid_desc_bk00_n_bk01_bk1_permute =
make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
b_grid_desc_bk00_n_bk01_bk1, b_grid_desc_bk00_n_bk01_bk1_permute,
make_tuple(make_merge_transform(make_tuple(BK00, BK01)), make_tuple(make_merge_transform(make_tuple(BK00, BK01)),
make_pass_through_transform(make_tuple(N)), make_pass_through_transform(make_tuple(N)),
make_pass_through_transform(BK1Value)), make_pass_through_transform(BK1Value)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif
return b_grid_desc_bk0_n_bk1; return b_grid_desc_bk0_n_bk1_permute;
#endif
} }
} }
......
...@@ -14,9 +14,7 @@ namespace ck { ...@@ -14,9 +14,7 @@ namespace ck {
inline __device__ int amd_assembly_and_or_b32(int a, int b, int d) inline __device__ int amd_assembly_and_or_b32(int a, int b, int d)
{ {
int c; int c;
asm volatile("v_and_or_b32 %0, %1, %2, %3" asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(c) : "v"(a), "v"(b), "v"(d));
: "=v"(c)
: "v"(a), "v"(b), "v"(d));
return c; return c;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment