Commit 20eb3b68 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed

parent 2af8f32a
......@@ -165,7 +165,11 @@ struct StaticTensorTupleOfVectorBuffer
// Get X
// Idx is for S, not X. Idx should be aligned with X
template <typename X, typename Idx>
template <typename X,
typename Idx,
typename enable_if<(has_same_scalar_type<S, X>::value || !is_native_type<S>()) &&
is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
__host__ __device__ constexpr X GetAsType(Idx) const
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
......@@ -195,7 +199,11 @@ struct StaticTensorTupleOfVectorBuffer
// Set X
// Idx is for S, not X. Idx should be aligned with X
template <typename X, typename Idx>
template <typename X,
typename Idx,
typename enable_if<(has_same_scalar_type<S, X>::value || !is_native_type<S>()) &&
is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
__host__ __device__ constexpr void SetAsType(Idx, X x)
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
......
......@@ -407,7 +407,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
else
{
// Weight Tile Permute
// Pre-shuffled Weight
// BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1]
constexpr index_t BK01 = KPerBlock / BK1Value;
// const index_t BK00 = BK0 / BK01;
const index_t BK0_ = StrideB / BK1Value;
......
......@@ -229,15 +229,9 @@ bool profile_gemm_universal_impl(int do_verification,
}
}
else
{
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j++)
{
b_k_n_permute(i * K + j) = b_k_n(i * K + j);
}
}
}
b_device_buf.ToDevice(b_k_n_permute.mData.data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment