Commit 65cfb2a1 authored by Jing Zhang
Browse files

format

parent 398f8851
...@@ -66,7 +66,6 @@ __device__ inline half2_t pki4_to_half2(pk_i4_t q) ...@@ -66,7 +66,6 @@ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB)); return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
#endif #endif
} }
struct PassThroughPack8 struct PassThroughPack8
...@@ -87,10 +86,14 @@ struct PassThroughPack8 ...@@ -87,10 +86,14 @@ struct PassThroughPack8
vector_type<half_t, 8> dst; vector_type<half_t, 8> dst;
vector_type<pk_i4_t, 4> src{x}; vector_type<pk_i4_t, 4> src{x};
dst.template AsType<half2_t>()(Number<0>{}) = pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]); dst.template AsType<half2_t>()(Number<0>{}) =
dst.template AsType<half2_t>()(Number<1>{}) = pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]); pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
dst.template AsType<half2_t>()(Number<2>{}) = pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]); dst.template AsType<half2_t>()(Number<1>{}) =
dst.template AsType<half2_t>()(Number<3>{}) = pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]); pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
dst.template AsType<half2_t>()(Number<2>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
dst.template AsType<half2_t>()(Number<3>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
y = dst.template AsType<half8_t>()[Number<0>{}]; y = dst.template AsType<half8_t>()[Number<0>{}];
#endif #endif
......
...@@ -1370,7 +1370,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1370,7 +1370,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
c_thread_buf, c_thread_buf,
num_k_block_main_loop); num_k_block_main_loop);
// shuffle C and write out // shuffle C and write out
{ {
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
......
...@@ -1025,8 +1025,7 @@ struct ThreadwiseTensorSliceTransfer_v4 ...@@ -1025,8 +1025,7 @@ struct ThreadwiseTensorSliceTransfer_v4
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>) if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
{ {
static_assert(SrcScalarPerVector % PackedSize == 0, static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
"pk data N cannot be 1");
} }
} }
...@@ -1126,8 +1125,9 @@ struct ThreadwiseTensorSliceTransfer_v4 ...@@ -1126,8 +1125,9 @@ struct ThreadwiseTensorSliceTransfer_v4
using src_vector_t = typename decltype(src_tmp_vector)::type; using src_vector_t = typename decltype(src_tmp_vector)::type;
//const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( // const bool is_src_valid =
//src_desc, src_data_coord); // coordinate_has_valid_offset_assuming_visible_index_is_valid( src_desc,
// src_data_coord);
const bool is_src_valid = true; const bool is_src_valid = true;
// copy data from src_buf into src_tmp_vector // copy data from src_buf into src_tmp_vector
......
...@@ -82,12 +82,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -82,12 +82,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>, static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
"SrcData != DstData"); "SrcData != DstData");
static_assert(SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0, static_assert(SrcScalarPerVector_ % PackedSize == 0 &&
DstScalarPerVector_ % PackedSize == 0,
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1"); "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1");
static_assert( static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
SrcVectorDim == DstVectorDim,
"pk_i4_t does not support transpose");
} }
} }
...@@ -875,8 +874,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -875,8 +874,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
private: private:
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
//static constexpr auto src_oob_thread_scratch_desc_ = // static constexpr auto src_oob_thread_scratch_desc_ =
//decltype(GetSrcThreadScratchDescriptor()){}; // decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
using SrcThreadScratch = using SrcThreadScratch =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment