"...composable_kernel.git" did not exist on "ea932fd3491a10a9e5f9812a50d15e2c155f545f"
Commit 5e127c69 authored by Jing Zhang

add read/write all

parent 6cf0fa5c
@@ -354,8 +354,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
}
#endif
#if 1
// output: register to global memory
#if 1
{
constexpr auto HoPerThreadx2 = HoPerThread * 2;
constexpr auto WoPerThreadx2 = WoPerThread * 2;
@@ -452,6 +452,106 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
}
}
}
#else
{
constexpr auto HoPerThreadx2 = HoPerThread * 2;
constexpr auto WoPerThreadx2 = WoPerThread * 2;
const index_t hox2_block_data_on_global = ho_block_work_id * HoPerBlock * 2;
const index_t wox2_block_data_on_global = wo_block_work_id * WoPerBlock * 2;
const index_t hox2_thread_data_on_global =
hox2_block_data_on_global + ho_thread_id * HoPerThreadx2;
const index_t wox2_thread_data_on_global =
wox2_block_data_on_global + wo_thread_id * WoPerThreadx2;
static_assert(KPerThread % CThreadTransferDstScalarPerVector == 0, "");
// static_assert(CThreadTransferDstScalarPerVector == 16, "");
constexpr auto KPerThreadAdd = KPerThread / CThreadTransferDstScalarPerVector;
const index_t k_block_data_on_global_add =
k_block_work_id * KPerBlock / CThreadTransferDstScalarPerVector;
const index_t k_thread_data_on_global_add =
k_block_data_on_global_add + k_thread_id * KPerThreadAdd;
constexpr auto d_k_n_hox2_wox2_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<KPerThreadAdd>{},
Number<1>{},
Number<HoPerThreadx2>{},
Number<WoPerThreadx2>{}));
constexpr auto vector_len = KPerThread * HoPerThreadx2 * WoPerThreadx2;
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
vector_type<int8_t, vector_len> d_vec;
auto d_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2<
FloatC,
decltype(d_vec),
decltype(d_k_n_hox2_wox2_global_desc),
decltype(d_k_n_hox2_wox2_thread_desc),
Sequence<KPerThreadAdd, 1, HoPerThreadx2, WoPerThreadx2>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
// CThreadTransferDstScalarPerVector,
1,
AddressSpace::Global,
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
1,
true>(d_k_n_hox2_wox2_global_desc,
make_multi_index(k_thread_data_on_global_add,
0,
hox2_thread_data_on_global,
wox2_thread_data_on_global));
auto c_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v1r3<
decltype(d_vec),
FloatC,
decltype(d_k_n_hox2_wox2_thread_desc),
decltype(d_k_n_hox2_wox2_global_desc),
Sequence<KPerThreadAdd, 1, HoPerThreadx2, WoPerThreadx2>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
// CThreadTransferDstScalarPerVector,
1,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation,
1,
true>(d_k_n_hox2_wox2_global_desc,
make_multi_index(k_thread_data_on_global_add,
0,
hox2_thread_data_on_global,
wox2_thread_data_on_global));
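// read the current d tile from p_d_global into the register-resident buffer d_vec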
d_threadwise_transfer.Run2(d_k_n_hox2_wox2_global_desc,
p_d_global,
d_k_n_hox2_wox2_thread_desc,
make_tuple(I0, I0, I0, I0),
d_vec,
c_k_n_ho_wo_global_tensor_iterator_hacks);
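// overwrite d_vec with the per-thread accumulators: each c value at (k, ho, wo)
// is replicated over a 2x2 (hox2, wox2) output window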
static_for<0, vector_len, 1>{}([&](auto i) {
constexpr auto kpack_i = i % (CThreadTransferDstScalarPerVector);
constexpr auto khw_i = i / (CThreadTransferDstScalarPerVector);
constexpr auto k_i = khw_i / (HoPerThreadx2 * WoPerThreadx2);
constexpr auto hw_i = khw_i % (HoPerThreadx2 * WoPerThreadx2);
constexpr auto h_i = hw_i / WoPerThreadx2;
constexpr auto w_i = hw_i % WoPerThreadx2;
d_vec.template AsType<int8_t>()(i) =
p_c_thread[c_k_n_ho_wo_thread_desc.CalculateOffset(make_tuple(
k_i * CThreadTransferDstScalarPerVector + kpack_i, 0, h_i / 2, w_i / 2))];
});
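// write the packed d_vec back out to p_c_global through the hox2/wox2 descriptor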
c_threadwise_transfer.Run2(d_k_n_hox2_wox2_thread_desc,
make_tuple(I0, I0, I0, I0),
d_vec,
d_k_n_hox2_wox2_global_desc,
p_c_global,
c_k_n_ho_wo_global_tensor_iterator_hacks);
}
#endif
}
...
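For reference, a minimal host-side sketch (not part of the commit) of the index decomposition performed by the packing loop above. The tile sizes below are stand-in example values only; ScalarPerVector stands in for CThreadTransferDstScalarPerVector.

#include <cstdio>

int main()
{
    // example values only; the kernel derives these from its tuning parameters
    constexpr int KPerThread      = 4;
    constexpr int ScalarPerVector = 2; // stands in for CThreadTransferDstScalarPerVector
    constexpr int HoPerThreadx2   = 2;
    constexpr int WoPerThreadx2   = 2;
    constexpr int vector_len      = KPerThread * HoPerThreadx2 * WoPerThreadx2;

    for(int i = 0; i < vector_len; ++i)
    {
        // same decomposition as the static_for body in the kernel
        const int kpack = i % ScalarPerVector;
        const int khw   = i / ScalarPerVector;
        const int k     = khw / (HoPerThreadx2 * WoPerThreadx2);
        const int hw    = khw % (HoPerThreadx2 * WoPerThreadx2);
        const int h     = hw / WoPerThreadx2;
        const int w     = hw % WoPerThreadx2;

        // every accumulator value c(k, ho, wo) is replicated over a 2x2 (hox2, wox2) window
        std::printf("d_vec[%2d] <- c(k=%d, ho=%d, wo=%d)\n",
                    i, k * ScalarPerVector + kpack, h / 2, w / 2);
    }
    return 0;
}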
@@ -403,6 +403,75 @@ struct vector_type<T, 16>
}
};
template <typename T>
struct vector_type<T, 256>
{
using d1_t = T;
typedef T d256_t __attribute__((ext_vector_type(256)));
typedef T d16_t __attribute__((ext_vector_type(16)));
using type = d256_t;
union
{
d256_t d256_;
StaticallyIndexedArray<d1_t, 256> d1x256_;
StaticallyIndexedArray<d16_t, 16> d16x16_;
StaticallyIndexedArray<d256_t, 1> d256x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value ||
is_same<X, d16_t>::value ||
is_same<X, d256_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x256_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x16_;
}
else if constexpr(is_same<X, d256_t>::value)
{
return data_.d256x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value ||
is_same<X, d16_t>::value ||
is_same<X, d256_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x256_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x16_;
}
else if constexpr(is_same<X, d256_t>::value)
{
return data_.d256x1_;
}
}
};
// fp32
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
@@ -429,6 +498,7 @@ using int8x2_t = typename vector_type<int8_t, 2>::type;
using int8x4_t = typename vector_type<int8_t, 4>::type;
using int8x8_t = typename vector_type<int8_t, 8>::type;
using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x256_t = typename vector_type<int8_t, 256>::type;
// data type conversion
template <typename T>
...
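As a usage illustration (not part of the commit), and assuming the surrounding composable_kernel utilities (static_for, Number, StaticallyIndexedArray) are in scope, the new 256-wide specialization lets a single register-resident buffer be addressed byte-wise, in 16-byte chunks, or as one int8x256_t, mirroring the AsType access pattern used in the gridwise kernel above. The function name and fill pattern below are illustrative only.

// hypothetical device-side sketch; not part of the commit
__device__ void int8x256_views_example()
{
    vector_type<int8_t, 256> buf;

    // byte-wise view: same pattern as d_vec.template AsType<int8_t>()(i) in the kernel above
    static_for<0, 256, 1>{}([&](auto i) {
        buf.template AsType<int8_t>()(i) = static_cast<int8_t>(i % 127);
    });

    // 16-byte-chunk view: e.g. as the source of a vectorized global store
    static_for<0, 16, 1>{}([&](auto j) {
        int8x16_t chunk = buf.template AsType<int8x16_t>()(j);
        (void)chunk;
    });

    // whole-register view: a single 256-lane vector
    int8x256_t whole = buf.template AsType<int8x256_t>()(Number<0>{});
    (void)whole;
}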