Unverified Commit 86d1b46a authored by Mingtao Gu's avatar Mingtao Gu Committed by GitHub
Browse files

fix a bug for int4 scale weight only kernel (#1820)


Co-authored-by: default avatarmtgu0705 <mtgu@amd.com>
parent bdddf1ea
......@@ -19,8 +19,6 @@ struct pk_i4_t
type data;
__host__ __device__ constexpr pk_i4_t() : data{type{}} {}
__host__ __device__ constexpr pk_i4_t(type init) : data{init} {}
__host__ __device__ constexpr operator float() const { return static_cast<int8_t>(data); }
};
inline constexpr auto next_pow2(uint32_t x)
......
......@@ -29,6 +29,13 @@ struct DynamicBuffer
ElementSpaceSize element_space_size_;
T invalid_element_value_ = T{0};
static constexpr index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<T>, pk_i4_t>)
return 2;
else
return 1;
}();
__host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
: p_data_{p_data}, element_space_size_{element_space_size}
{
......@@ -82,14 +89,18 @@ struct DynamicBuffer
return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>,
t_per_x,
coherence>(
p_data_, i, is_valid_element, element_space_size_);
p_data_, i, is_valid_element, element_space_size_ / PackedSize);
}
else
{
return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
t_per_x,
coherence>(
p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
p_data_,
i,
is_valid_element,
element_space_size_ / PackedSize,
invalid_element_value_);
}
}
else
......@@ -191,7 +202,7 @@ struct DynamicBuffer
dst_buf.p_data_,
dst_offset,
is_valid_element,
element_space_size_);
element_space_size_ / PackedSize);
}
template <typename X,
......@@ -226,7 +237,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
x, p_data_, i, is_valid_element, element_space_size_);
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
}
else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value &&
......@@ -378,7 +389,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_);
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
}
else
{
......@@ -417,7 +428,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_);
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
}
else if(is_valid_element)
{
......
......@@ -342,7 +342,7 @@ bool profile_gemm_b_scale_impl(int do_verification,
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<int8_t>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_result.mData, ",")
<< std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment