Commit 3a84f68e authored by Jing Zhang

fixed

parent bf545630
@@ -37,6 +37,7 @@ struct DeviceGemmV2 : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+    virtual bool GetPermuteA() = 0;
     virtual bool GetPermuteB() = 0;
     virtual ck::index_t GetKPerBlock() = 0;
 };
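Note: the added pure virtual GetPermuteA() mirrors the existing GetPermuteB(), so every concrete DeviceGemmV2 op must now report whether its A operand is pre-permuted. A self-contained sketch of the override pattern (DeviceGemmV2Sketch and DeviceGemmExample are illustrative stand-ins, not CK names):

    #include <cstdint>

    // Reduced model of the interface shape only; the real base is CK's
    // DeviceGemmV2 with many more members.
    struct DeviceGemmV2Sketch
    {
        virtual ~DeviceGemmV2Sketch()       = default;
        virtual bool GetPermuteA()          = 0;
        virtual bool GetPermuteB()          = 0;
        virtual std::int32_t GetKPerBlock() = 0;
    };

    // A concrete op advertises its compile-time layout choices at runtime.
    template <bool PermuteA, bool PermuteB, std::int32_t KPerBlock>
    struct DeviceGemmExample final : DeviceGemmV2Sketch
    {
        bool GetPermuteA() override { return PermuteA; }
        bool GetPermuteB() override { return PermuteB; }
        std::int32_t GetKPerBlock() override { return KPerBlock; }
    };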
@@ -410,7 +410,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
         // Pre-shuffled Weight
         // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1]
         constexpr index_t BK01 = KPerBlock / BK1Value;
-        // const index_t BK00 = BK0 / BK01;
         const index_t BK0_ = StrideB / BK1Value;
         const index_t BK00 = BK0_ / BK01;
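With the dead comment gone, the live math splits K as BK00 x BK01 x BK1 for the pre-shuffled layout described above; BK0_ corresponds to K / K1, the first BTile dimension. A worked example with assumed tile sizes (illustrative numbers only, and assuming StrideB carries the K extent here, as the division by BK1Value suggests):

    // KPerBlock = 64, BK1Value = 8, StrideB = 256 are assumptions for the example.
    constexpr int KPerBlock = 64;
    constexpr int BK1Value  = 8;
    constexpr int StrideB   = 256;
    constexpr int BK01 = KPerBlock / BK1Value; // 8  K-chunks inside one block tile
    constexpr int BK0_ = StrideB / BK1Value;   // 32 = K / K1, first BTile dimension
    constexpr int BK00 = BK0_ / BK01;          // 4  block tiles along K
    static_assert(BK00 == 4, "sanity check for the example numbers");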
@@ -1137,7 +1137,6 @@ struct ThreadwiseTensorSliceTransfer_v4
         }
         else if constexpr(SrcBuffer::IsStaticBuffer())
         {
-            static_assert(false, "");
             static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                 constexpr index_t src_offset = src_desc.CalculateOffset(
                     src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
@@ -82,9 +82,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
                       "SrcData != DstData");
-        static_assert(SrcScalarPerVector_ % PackedSize == 0 &&
-                          DstScalarPerVector_ % PackedSize == 0,
-                      "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1");
+        static_assert(
+            SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
+            "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
         static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
     }
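For context: PackedSize is 2 for pk_i4_t (two 4-bit values packed in one byte) and 1 otherwise, so the reworded assert rejects vector widths that would split a packed byte. The same constraint in isolation, with all names local to the sketch:

    #include <cstdint>
    #include <type_traits>

    struct pk_i4_t { std::int8_t data; }; // two int4 lanes in one byte

    template <typename T>
    constexpr int PackedSize = std::is_same_v<T, pk_i4_t> ? 2 : 1;

    // Vector widths must be whole multiples of the packing factor; otherwise a
    // transfer would have to split a byte between two vectors.
    template <typename T, int SrcScalarPerVector, int DstScalarPerVector>
    constexpr void check_vector_widths()
    {
        static_assert(SrcScalarPerVector % PackedSize<T> == 0 &&
                          DstScalarPerVector % PackedSize<T> == 0,
                      "SrcScalarPerVector and DstScalarPerVector cannot be 1 for packed data");
    }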
@@ -234,8 +234,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
         using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;

-        static_assert(elem_op_vec_len == 1, "elem_op_vec_len != 1");
-
         auto src_vector_container = src_vector_type{
             src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, true)};
@@ -300,13 +298,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     TransferDataFromSrcThreadScratchToDstThreadScratch(Number<ThreadScratchId> thread_scratch_id)
     {
 #if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
-        static_assert(false, "");
-
         static_ford<SliceLengths>{}([&](auto idx) {
             dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
         });
 #else
-#if 1
         // OOB Check
         constexpr auto src_scalar_per_access = generate_sequence(
             detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
@@ -369,7 +364,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             src_thread_scratch_tuple_(thread_scratch_id)
                 .template SetAsType<vector_t>(src_data_idx_seq, op_r_v);
         });
-#endif

         // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
         // TODO make this logic more generic for more sub-dword datatype
...@@ -381,9 +375,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -381,9 +375,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
(is_same<f8_t, remove_cvref_t<DstData>>::value && (is_same<f8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
{ {
// static_assert(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>, static_assert(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>,
//"transpose is not allowed for pk_i4_t"); "transpose is not allowed for pk_i4_t");
#if 1
// each transpose does // each transpose does
// DstScalarPerVector # of src vectors in src_thread_scratch_ // DstScalarPerVector # of src vectors in src_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_ // SrcScalarPerVector # of dst vectors in dst_thread_scratch_
@@ -441,7 +434,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
                     src_vector_refs, dst_vector_refs);
             });
-#endif
         }
         else
         {
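The comments in the hunk above describe the shape of each in-register transpose: DstScalarPerVector source vectors in src_thread_scratch_ become SrcScalarPerVector destination vectors in dst_thread_scratch_. A scalar reference version of that data movement, purely illustrative (CK's transpose_vectors performs the equivalent with packed dword swizzles, not a loop):

    #include <array>
    #include <cstdint>

    // S plays the role of DstScalarPerVector (number of source vectors) and
    // D of SrcScalarPerVector (elements per source vector): S vectors of D
    // bytes become D vectors of S bytes.
    template <int S, int D>
    void transpose_bytes(const std::array<std::array<std::uint8_t, D>, S>& src,
                         std::array<std::array<std::uint8_t, S>, D>& dst)
    {
        for(int s = 0; s < S; ++s)
            for(int d = 0; d < D; ++d)
                dst[d][s] = src[s][d]; // element (s, d) lands at (d, s)
    }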
...@@ -429,6 +429,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w ...@@ -429,6 +429,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
(is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, pk_i4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), (is_same<T, pk_i4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented"); "wrong! not implemented");
......
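This hunk only extends a compile-time whitelist: buffer loads are specialized per (element type, vector width) pair, and uint8_t now gets the same N in {1, 2, 4, 8, 16} set that int8_t already has. The guard pattern reduced to its essentials (buffer_load_sketch is an illustrative name, not a CK function):

    #include <cstdint>
    #include <type_traits>

    // Reject (T, N) combinations that have no load specialization up front,
    // instead of failing later with an opaque template error.
    template <typename T, int N>
    void buffer_load_sketch()
    {
        constexpr bool type_ok =
            std::is_same_v<T, std::int8_t> || std::is_same_v<T, std::uint8_t>;
        constexpr bool width_ok = N == 1 || N == 2 || N == 4 || N == 8 || N == 16;
        static_assert(type_ok && width_ok, "wrong! not implemented");
        // ... dispatch to the matching buffer-load specialization would go here ...
    }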
@@ -1893,6 +1893,14 @@ using bf8x32_t = bf8x32_fnuz_t;
 using bf8x64_t = bf8x64_fnuz_t;
 #endif

+// u8
+using uint8x2_t = typename vector_type<uint8_t, 2>::type;
+using uint8x4_t = typename vector_type<uint8_t, 4>::type;
+using uint8x8_t = typename vector_type<uint8_t, 8>::type;
+using uint8x16_t = typename vector_type<uint8_t, 16>::type;
+using uint8x32_t = typename vector_type<uint8_t, 32>::type;
+using uint8x64_t = typename vector_type<uint8_t, 64>::type;
+
 // pack int4
 using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
 using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
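CK's vector_type<T, N>::type resolves to a Clang/HIP extended vector, so the new aliases are roughly equivalent to the following standalone sketch (the *_sketch_t names are illustrative; the real definitions go through the vector_type wrapper):

    #include <cstdint>

    // Clang extended vectors of N uint8 lanes, accessed like other
    // ext_vector_type vectors (.x/.y/.z/.w or operator[]).
    typedef std::uint8_t uint8x2_sketch_t __attribute__((ext_vector_type(2)));
    typedef std::uint8_t uint8x4_sketch_t __attribute__((ext_vector_type(4)));

    int main()
    {
        uint8x4_sketch_t v = {1, 2, 3, 4};
        return v.x + v.w; // 1 + 4 after integer promotion
    }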
@@ -29,13 +29,6 @@ struct DynamicBuffer
     ElementSpaceSize element_space_size_;
     T invalid_element_value_ = T{0};

-    static constexpr index_t PackedSize = []() {
-        if constexpr(is_same_v<remove_cvref_t<T>, pk_i4_t>)
-            return 2;
-        else
-            return 1;
-    }();
-
     __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
         : p_data_{p_data}, element_space_size_{element_space_size}
     {
@@ -59,7 +52,11 @@ struct DynamicBuffer
     __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }

-    template <typename X>
+    template <typename X,
+              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
+                                         typename scalar_type<remove_cvref_t<T>>::type>::value ||
+                                     !is_native_type<X>(),
+                                 bool>::type = false>
     __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
     {
         // X contains multiple T
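The widened template head removes Get() (and, below, Set()) from overload resolution unless the scalar type of the access vector X matches that of the buffer element T, or X is a non-native packed type such as pk_i4_t, turning a wrong-type access into a clean substitution failure. A reduced model of the same SFINAE guard, with std::array standing in for a vector type and value_type standing in for CK's scalar_type trait:

    #include <array>
    #include <type_traits>

    // Set() participates in overload resolution only when the scalar element
    // of X matches the buffer element T. (CK additionally admits non-native
    // packed types via !is_native_type<X>().)
    template <typename T>
    struct DynamicBufferSketch
    {
        template <typename X,
                  typename std::enable_if<std::is_same<typename X::value_type, T>::value,
                                          bool>::type = false>
        void Set(int i, const X& x) { (void)i; (void)x; /* store x at element i */ }
    };

    // DynamicBufferSketch<float>{}.Set(0, std::array<float, 4>{}); // OK
    // DynamicBufferSketch<float>{}.Set(0, std::array<int, 4>{});   // substitution failure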
@@ -85,18 +82,14 @@ struct DynamicBuffer
                 return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>,
                                                                    t_per_x,
                                                                    coherence>(
-                    p_data_, i, is_valid_element, element_space_size_ / PackedSize);
+                    p_data_, i, is_valid_element, element_space_size_);
             }
             else
             {
                 return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
                                                                                t_per_x,
                                                                                coherence>(
-                    p_data_,
-                    i,
-                    is_valid_element,
-                    element_space_size_ / PackedSize,
-                    invalid_element_value_);
+                    p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
             }
         }
         else
@@ -198,10 +191,14 @@ struct DynamicBuffer
             dst_buf.p_data_,
             dst_offset,
             is_valid_element,
-            element_space_size_ / PackedSize);
+            element_space_size_);
     }

-    template <typename X>
+    template <typename X,
+              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
+                                         typename scalar_type<remove_cvref_t<T>>::type>::value ||
+                                     !is_native_type<X>(),
+                                 bool>::type = false>
     __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
     {
         // X contains multiple T
@@ -229,7 +226,7 @@ struct DynamicBuffer
             constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

             amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
-                x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
+                x, p_data_, i, is_valid_element, element_space_size_);
         }
         else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
                           is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value &&
@@ -381,7 +378,7 @@ struct DynamicBuffer
             constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

             amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
-                x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
+                x, p_data_, i, is_valid_element, element_space_size_);
         }
         else
         {
@@ -420,7 +417,7 @@ struct DynamicBuffer
             constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

             amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
-                x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
+                x, p_data_, i, is_valid_element, element_space_size_);
         }
         else if(is_valid_element)
         {