Commit e871c55b authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed copy

parent c6e072a6
...@@ -367,7 +367,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -367,7 +367,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
static_assert(c_k_n_ho_wo_thread_desc_vec.GetElementSpaceSize() == 4, ""); static_assert(c_k_n_ho_wo_thread_desc_vec.GetElementSpaceSize() == 4, "");
FloatC d_vec[c_k_n_ho_wo_thread_desc_vec.GetElementSpaceSize()]; const index_t vec_len = c_k_n_ho_wo_thread_desc_vec.GetElementSpaceSize() *
CThreadTransferDstScalarPerVector;
vector_type<int8_t, vec_len> d_vec;
// FloatC d_vec[c_k_n_ho_wo_thread_desc_vec.GetElementSpaceSize()];
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{}; constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
...@@ -376,13 +381,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -376,13 +381,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
static_for<0, WoPerThread, 1>{}([&](auto w_i) { static_for<0, WoPerThread, 1>{}([&](auto w_i) {
vector_type<int8_t, CThreadTransferDstScalarPerVector> t; vector_type<int8_t, CThreadTransferDstScalarPerVector> t;
// t.template AsType<FloatC>()(Number<0>{}) = d_vec.template AsType< t.template AsType<FloatC>()(Number<0>{}) = d_vec.template AsType<
// FloatC>()[Number<c_k_n_ho_wo_thread_desc_vec.CalculateOffset( FloatC>()[Number<c_k_n_ho_wo_thread_desc_vec.CalculateOffset(
// make_tuple(k_i, 0, h_i, w_i))>{}]; make_tuple(k_i, 0, h_i, w_i))>{}];
t.template AsType<FloatC>()(Number<0>{}) = // t.template AsType<FloatC>()(Number<0>{}) =
d_vec[Number<c_k_n_ho_wo_thread_desc_vec.CalculateOffset( // d_vec[Number<c_k_n_ho_wo_thread_desc_vec.CalculateOffset(
make_tuple(k_i, 0, h_i, w_i))>{}]; // make_tuple(k_i, 0, h_i, w_i))>{}];
static_for<0, CThreadTransferDstScalarPerVector, 1>{}([&](auto i) { static_for<0, CThreadTransferDstScalarPerVector, 1>{}([&](auto i) {
t.template AsType<int8_t>()(i) = t.template AsType<int8_t>()(i) =
...@@ -390,18 +395,19 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -390,18 +395,19 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
k_i * CThreadTransferDstScalarPerVector + i, 0, h_i, w_i))]; k_i * CThreadTransferDstScalarPerVector + i, 0, h_i, w_i))];
}); });
// d_vec.template AsType<FloatC>()( d_vec.template AsType<FloatC>()(
// Number<c_k_n_ho_wo_thread_desc_vec.CalculateOffset(make_tuple( Number<c_k_n_ho_wo_thread_desc_vec.CalculateOffset(make_tuple(
// k_i, 0, h_i, w_i))>{}) = t.template AsType<FloatC>()[Number<0>{}]; k_i, 0, h_i, w_i))>{}) = t.template AsType<FloatC>()[Number<0>{}];
d_vec[Number<c_k_n_ho_wo_thread_desc_vec.CalculateOffset(make_tuple( // d_vec[Number<c_k_n_ho_wo_thread_desc_vec.CalculateOffset(make_tuple(
k_i, 0, h_i, w_i))>{}] = t.template AsType<FloatC>()[Number<0>{}]; // k_i, 0, h_i, w_i))>{}] = t.template AsType<FloatC>()[Number<0>{}];
}); });
}); });
}); });
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatC, // FloatC,
decltype(d_vec),
FloatC, FloatC,
decltype(c_k_n_ho_wo_thread_desc_vec), decltype(c_k_n_ho_wo_thread_desc_vec),
decltype(c_k_n_ho_wo_global_desc), decltype(c_k_n_ho_wo_global_desc),
......
...@@ -377,7 +377,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -377,7 +377,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx + src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx +
i * dst_scalar_step_in_vector); i * dst_scalar_step_in_vector);
dst_vector.template AsType<DstData>()(i) = p_src.template AsType<DstData>()[i]; dst_vector.template AsType<DstData>()(i) =
p_src.template AsType<DstData>()[Number<src_offset>{}];
}); });
const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
......
...@@ -403,6 +403,74 @@ struct vector_type<T, 16> ...@@ -403,6 +403,74 @@ struct vector_type<T, 16>
} }
}; };
// Specialization of vector_type for 64 scalar elements of T.
// The single 64-wide storage can be reinterpreted, via AsType<X>(), as:
//   - 64 individual scalars (d1_t),
//   - 4 vectors of 16 (d16_t), or
//   - 1 vector of 64 (d64_t, also exposed as `type`).
template <typename T>
struct vector_type<T, 64>
{
    using d1_t = T;
    typedef T d64_t __attribute__((ext_vector_type(64)));
    typedef T d16_t __attribute__((ext_vector_type(16)));

    using type = d64_t;

    // All views alias the same 64 * sizeof(T) bytes of storage.
    union
    {
        d64_t d64_;
        StaticallyIndexedArray<d1_t, 64> d1x64_;
        // 64 scalars == 4 x 16-wide vectors (count was 16, which would
        // describe 256 scalars and break the union's size invariant)
        StaticallyIndexedArray<d16_t, 4> d16x4_;
        // was misnamed d16x1_; AsType<d64_t>() below (and the naming
        // convention of the other specializations, e.g. d256x1_) expects d64x1_
        StaticallyIndexedArray<d64_t, 1> d64x1_;
    } data_;

    __host__ __device__ constexpr vector_type() : data_{type{0}} {}

    __host__ __device__ constexpr vector_type(type v) : data_{v} {}

    // Const view of the storage as an array of X-typed elements.
    template <typename X>
    __host__ __device__ constexpr const auto& AsType() const
    {
        static_assert(is_same<X, d1_t>::value ||
                      is_same<X, d16_t>::value ||
                      is_same<X, d64_t>::value,
                      "wrong!");

        if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x64_;
        }
        else if constexpr(is_same<X, d16_t>::value)
        {
            return data_.d16x4_;
        }
        else if constexpr(is_same<X, d64_t>::value)
        {
            return data_.d64x1_;
        }
    }

    // Mutable view of the storage as an array of X-typed elements.
    template <typename X>
    __host__ __device__ constexpr auto& AsType()
    {
        static_assert(is_same<X, d1_t>::value ||
                      is_same<X, d16_t>::value ||
                      is_same<X, d64_t>::value,
                      "wrong!");

        if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x64_;
        }
        else if constexpr(is_same<X, d16_t>::value)
        {
            return data_.d16x4_;
        }
        else if constexpr(is_same<X, d64_t>::value)
        {
            return data_.d64x1_;
        }
    }
};
template <typename T> template <typename T>
struct vector_type<T, 256> struct vector_type<T, 256>
{ {
...@@ -467,9 +535,6 @@ struct vector_type<T, 256> ...@@ -467,9 +535,6 @@ struct vector_type<T, 256>
return data_.d256x1_; return data_.d256x1_;
} }
} }
}; };
// fp32 // fp32
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment