Commit 20eb3b68 authored by Jing Zhang

fixed

parent 2af8f32a
@@ -581,7 +581,7 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS)
     )
     add_subdirectory(example)
     if(BUILD_TESTING)
         add_subdirectory(test)
     endif()
 endif()
...
@@ -165,7 +165,11 @@ struct StaticTensorTupleOfVectorBuffer
     // Get X
     // Idx is for S, not X. Idx should be aligned with X
-    template <typename X, typename Idx>
+    template <typename X,
+              typename Idx,
+              typename enable_if<(has_same_scalar_type<S, X>::value || !is_native_type<S>()) &&
+                                     is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
+                                 bool>::type = false>
     __host__ __device__ constexpr X GetAsType(Idx) const
     {
         constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
@@ -195,7 +199,11 @@ struct StaticTensorTupleOfVectorBuffer
     // Set X
     // Idx is for S, not X. Idx should be aligned with X
-    template <typename X, typename Idx>
+    template <typename X,
+              typename Idx,
+              typename enable_if<(has_same_scalar_type<S, X>::value || !is_native_type<S>()) &&
+                                     is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
+                                 bool>::type = false>
     __host__ __device__ constexpr void SetAsType(Idx, X x)
     {
         constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
...
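Aside: the enable_if constraint added above gates GetAsType/SetAsType so they only take part in overload resolution when X shares the scalar type of S (or S is non-native) and Idx is a compile-time index whose rank equals ndim_. Below is a minimal standalone sketch of that SFINAE pattern using std::enable_if and hypothetical stand-ins (has_same_scalar_type, TinyBuffer, Index2); it is not CK's actual buffer implementation.

#include <type_traits>

// Hypothetical stand-in for CK's has_same_scalar_type trait (illustration only).
template <typename S, typename X>
struct has_same_scalar_type : std::is_same<S, X>
{
};

template <typename S, int ndim_>
struct TinyBuffer
{
    // Only well-formed when X matches S's scalar type and Idx is a
    // compile-time index whose Size() equals the tensor rank.
    template <typename X,
              typename Idx,
              typename std::enable_if<has_same_scalar_type<S, X>::value &&
                                          Idx::Size() == ndim_,
                                      bool>::type = false>
    constexpr X GetAsType(Idx) const
    {
        return X{};
    }
};

// Hypothetical compile-time index type with a constexpr Size().
struct Index2
{
    static constexpr int Size() { return 2; }
};

int main()
{
    TinyBuffer<float, 2> buf;
    (void)buf.GetAsType<float>(Index2{}); // OK: scalar type and rank match
    // buf.GetAsType<double>(Index2{});   // substitution failure: scalar types differ
    return 0;
}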
@@ -407,7 +407,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
         }
         else
         {
-            // Weight Tile Permute
+            // Pre-shuffled Weight
+            // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1]
             constexpr index_t BK01 = KPerBlock / BK1Value;
             // const index_t BK00 = BK0 / BK01;
             const index_t BK0_ = StrideB / BK1Value;
...
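One plausible host-side reading of the layout named in the new comment, assuming row-major packing of both views; the actual device-side descriptor transforms in GridwiseGemm_xdl_cshuffle_v3 are not reproduced here, and pack_b_tile is a hypothetical helper for illustration only.

#include <cstddef>
#include <vector>

// Sketch (not CK code): pack BGlobal[K / KPerBlock, N, KPerBlock / K1, K1]
// into BTile[K / K1, N, K1], assuming both views are stored row-major.
// The two K dimensions of BGlobal flatten into BTile's K0 = K / K1 because
// (K / KPerBlock) * (KPerBlock / K1) == K / K1.
std::vector<float> pack_b_tile(const std::vector<float>& b_global,
                               int K, int N, int KPerBlock, int K1)
{
    const int KBlocks = K / KPerBlock;  // outer K dim of BGlobal
    const int BK01    = KPerBlock / K1; // inner K dim of BGlobal (BK01 above)

    std::vector<float> b_tile(static_cast<std::size_t>(K / K1) * N * K1);

    for(int kb = 0; kb < KBlocks; ++kb)
        for(int n = 0; n < N; ++n)
            for(int k01 = 0; k01 < BK01; ++k01)
                for(int k1 = 0; k1 < K1; ++k1)
                {
                    const int k0 = kb * BK01 + k01; // flattened K0 index of BTile
                    b_tile[(static_cast<std::size_t>(k0) * N + n) * K1 + k1] =
                        b_global[((static_cast<std::size_t>(kb) * N + n) * BK01 + k01) * K1 + k1];
                }

    return b_tile;
}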
@@ -230,13 +230,7 @@ bool profile_gemm_universal_impl(int do_verification,
         }
         else
         {
-            for(int i = 0; i < N; i++)
-            {
-                for(int j = 0; j < K; j++)
-                {
-                    b_k_n_permute(i * K + j) = b_k_n(i * K + j);
-                }
-            }
+            b_k_n_permute(i * K + j) = b_k_n(i * K + j);
         }
         b_device_buf.ToDevice(b_k_n_permute.mData.data());
...
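The nested loops removed above performed an element-wise identity copy of all N * K entries from b_k_n into b_k_n_permute. A minimal host-side sketch of that equivalence, using plain std::vector rather than the profiler's tensor types:

#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch only: the removed i/j loops copied b_k_n into b_k_n_permute unchanged,
// which is equivalent to one flat copy of the N * K backing buffer.
void copy_b_unpermuted(const std::vector<float>& b_k_n,
                       std::vector<float>& b_k_n_permute,
                       int N, int K)
{
    b_k_n_permute.resize(static_cast<std::size_t>(N) * K);
    std::copy_n(b_k_n.begin(), static_cast<std::size_t>(N) * K, b_k_n_permute.begin());
}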