Commit 65cfb2a1 authored by Jing Zhang
Browse files

format

parent 398f8851
...@@ -55,8 +55,8 @@ __device__ inline half2_t pki4_to_half2(pk_i4_t q) ...@@ -55,8 +55,8 @@ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
#else #else
uint8_t x_u8 = ck::bit_cast<uint8_t>(q); uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
int x_l = (x_u8 & 0x0f); int x_l = (x_u8 & 0x0f);
int x_h = (x_u8 & 0xf0) << 12; int x_h = (x_u8 & 0xf0) << 12;
const int EX = 0x64006400; const int EX = 0x64006400;
...@@ -66,7 +66,6 @@ __device__ inline half2_t pki4_to_half2(pk_i4_t q) ...@@ -66,7 +66,6 @@ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB)); return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
#endif #endif
} }
struct PassThroughPack8 struct PassThroughPack8
...@@ -87,12 +86,16 @@ struct PassThroughPack8 ...@@ -87,12 +86,16 @@ struct PassThroughPack8
vector_type<half_t, 8> dst; vector_type<half_t, 8> dst;
vector_type<pk_i4_t, 4> src{x}; vector_type<pk_i4_t, 4> src{x};
dst.template AsType<half2_t>()(Number<0>{}) = pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]); dst.template AsType<half2_t>()(Number<0>{}) =
dst.template AsType<half2_t>()(Number<1>{}) = pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]); pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
dst.template AsType<half2_t>()(Number<2>{}) = pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]); dst.template AsType<half2_t>()(Number<1>{}) =
dst.template AsType<half2_t>()(Number<3>{}) = pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]); pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
dst.template AsType<half2_t>()(Number<2>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
dst.template AsType<half2_t>()(Number<3>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
y = dst.template AsType<half8_t>()[Number<0>{}]; y = dst.template AsType<half8_t>()[Number<0>{}];
#endif #endif
} }
......
...@@ -1370,7 +1370,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1370,7 +1370,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
c_thread_buf, c_thread_buf,
num_k_block_main_loop); num_k_block_main_loop);
// shuffle C and write out // shuffle C and write out
{ {
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
......
...@@ -1025,8 +1025,7 @@ struct ThreadwiseTensorSliceTransfer_v4 ...@@ -1025,8 +1025,7 @@ struct ThreadwiseTensorSliceTransfer_v4
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>) if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
{ {
static_assert(SrcScalarPerVector % PackedSize == 0, static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
"pk data N cannot be 1");
} }
} }
...@@ -1126,8 +1125,9 @@ struct ThreadwiseTensorSliceTransfer_v4 ...@@ -1126,8 +1125,9 @@ struct ThreadwiseTensorSliceTransfer_v4
using src_vector_t = typename decltype(src_tmp_vector)::type; using src_vector_t = typename decltype(src_tmp_vector)::type;
//const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( // const bool is_src_valid =
//src_desc, src_data_coord); // coordinate_has_valid_offset_assuming_visible_index_is_valid( src_desc,
// src_data_coord);
const bool is_src_valid = true; const bool is_src_valid = true;
// copy data from src_buf into src_tmp_vector // copy data from src_buf into src_tmp_vector
......
...@@ -80,14 +80,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -80,14 +80,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>) if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
{ {
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>, static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
"SrcData != DstData"); "SrcData != DstData");
static_assert(SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0, static_assert(SrcScalarPerVector_ % PackedSize == 0 &&
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1"); DstScalarPerVector_ % PackedSize == 0,
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1");
static_assert( static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
SrcVectorDim == DstVectorDim,
"pk_i4_t does not support transpose");
} }
} }
...@@ -446,7 +445,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -446,7 +445,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
else else
{ {
constexpr auto packed_per_access = generate_sequence( constexpr auto packed_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{}); detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});
constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access; constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;
...@@ -875,8 +874,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -875,8 +874,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
private: private:
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
//static constexpr auto src_oob_thread_scratch_desc_ = // static constexpr auto src_oob_thread_scratch_desc_ =
//decltype(GetSrcThreadScratchDescriptor()){}; // decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
using SrcThreadScratch = using SrcThreadScratch =
......
...@@ -82,7 +82,7 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -82,7 +82,7 @@ struct ReferenceGemm : public device::BaseOperator
i4 = (i4x2 >> 0) & 0xf; i4 = (i4x2 >> 0) & 0xf;
else else
i4 = (i4x2 >> 4) & 0xf; i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8; i4 = i4 - 8;
v_a = type_convert<ComputeTypeA>(i4); v_a = type_convert<ComputeTypeA>(i4);
} }
else else
...@@ -103,7 +103,7 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -103,7 +103,7 @@ struct ReferenceGemm : public device::BaseOperator
i4 = (i4x2 >> 0) & 0xf; i4 = (i4x2 >> 0) & 0xf;
else else
i4 = (i4x2 >> 4) & 0xf; i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8; i4 = i4 - 8;
v_b = type_convert<ComputeTypeB>(i4); v_b = type_convert<ComputeTypeB>(i4);
} }
else else
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment