"examples/community/lpw_stable_diffusion_xl.py" did not exist on "12a232efa99d7a8c33f54ae515c5a3d6fc5c8f34"
Commit 332f9039 authored by Jing Zhang
Browse files

vector add and out

parent 03aa52bc
...@@ -47,7 +47,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -47,7 +47,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const InRightPads& in_right_pads, const InRightPads& in_right_pads,
const FloatAB* __restrict__ p_wei_global, const FloatAB* __restrict__ p_wei_global,
const FloatAB* __restrict__ p_in_global, const FloatAB* __restrict__ p_in_global,
const FloatC* __restrict__ p_d_global, FloatAB* __restrict__ p_d_global,
FloatC* __restrict__ p_out_global) const FloatC* __restrict__ p_out_global) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
...@@ -271,7 +271,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -271,7 +271,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(add_k_n_hopx2_wopx2_global_desc), decltype(add_k_n_hopx2_wopx2_global_desc),
const FloatC*, FloatAB*,
decltype(out_k_n_hop_wop_global_desc), decltype(out_k_n_hop_wop_global_desc),
FloatC*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
...@@ -302,7 +302,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -302,7 +302,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(add_k_n_hopx2_wopx2_global_desc), decltype(add_k_n_hopx2_wopx2_global_desc),
const FloatC*, FloatAB*,
decltype(out_k_n_hop_wop_global_desc), decltype(out_k_n_hop_wop_global_desc),
FloatC*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
...@@ -333,7 +333,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -333,7 +333,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(add_k_n_hopx2_wopx2_global_desc), decltype(add_k_n_hopx2_wopx2_global_desc),
const FloatC*, FloatAB*,
decltype(out_k_n_hop_wop_global_desc), decltype(out_k_n_hop_wop_global_desc),
FloatC*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
...@@ -364,7 +364,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -364,7 +364,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(add_k_n_hopx2_wopx2_global_desc), decltype(add_k_n_hopx2_wopx2_global_desc),
const FloatC*, FloatAB*,
decltype(out_k_n_hop_wop_global_desc), decltype(out_k_n_hop_wop_global_desc),
FloatC*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
......
...@@ -75,7 +75,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -75,7 +75,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const BGlobalDesc& b_e_n_ho_wo_global_desc, const BGlobalDesc& b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const DGlobalDesc& d_k_n_hox2_wox2_global_desc, const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
const FloatC* __restrict__ p_d_global, FloatAB* __restrict__ p_d_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc, const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block, FloatAB* __restrict__ p_shared_block,
...@@ -174,6 +174,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -174,6 +174,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const index_t wo_thread_data_on_global = const index_t wo_thread_data_on_global =
wo_block_data_on_global + wo_thread_id * WoPerThread; wo_block_data_on_global + wo_thread_id * WoPerThread;
#if 0
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize, BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
...@@ -267,7 +268,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -267,7 +268,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
index_t b_block_data_begin = 0; index_t b_block_data_begin = 0;
#if 1
if constexpr(HasMainKBlockLoop) if constexpr(HasMainKBlockLoop)
{ {
FloatAB* p_b_thread_even = p_b_thread_double; FloatAB* p_b_thread_even = p_b_thread_double;
...@@ -365,35 +365,44 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -365,35 +365,44 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
hox2_block_data_on_global + ho_thread_id * HoPerThreadx2; hox2_block_data_on_global + ho_thread_id * HoPerThreadx2;
const index_t wox2_thread_data_on_global = const index_t wox2_thread_data_on_global =
wox2_block_data_on_global + wo_thread_id * WoPerThreadx2; wox2_block_data_on_global + wo_thread_id * WoPerThreadx2;
const index_t k_thread_data_on_global =
k_block_data_on_global + k_thread_id * KPerThread; static_assert(KPerThread % 16 == 0, "");
constexpr auto KPerThreadAdd = KPerThread / 16;
const index_t k_block_data_on_global_add = k_block_work_id * KPerBlock / 16;
const index_t k_thread_data_on_global_add =
k_block_data_on_global_add + k_thread_id * KPerThreadAdd;
constexpr auto d_k_n_hox2_wox2_thread_desc = constexpr auto d_k_n_hox2_wox2_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<KPerThread>{}, make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<KPerThreadAdd>{},
Number<1>{}, Number<1>{},
Number<HoPerThreadx2>{}, Number<HoPerThreadx2>{},
Number<WoPerThreadx2>{})); Number<WoPerThreadx2>{}));
FloatC p_d_thread[d_k_n_hox2_wox2_thread_desc.GetElementSpaceSize()]; FloatAB p_d_thread[d_k_n_hox2_wox2_thread_desc.GetElementSpaceSize()];
constexpr auto vector_len = sizeof(FloatAB) / sizeof(FloatC);
static_assert(vector_len == 16);
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{}; constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
#if 1 #if 1
ThreadwiseDynamicTensorSliceTransfer_v2< ThreadwiseDynamicTensorSliceTransfer_v2<
FloatC, FloatAB,
FloatC, FloatAB,
decltype(d_k_n_hox2_wox2_global_desc), decltype(d_k_n_hox2_wox2_global_desc),
decltype(d_k_n_hox2_wox2_thread_desc), decltype(d_k_n_hox2_wox2_thread_desc),
Sequence<KPerThread, 1, HoPerThreadx2, WoPerThreadx2>, Sequence<KPerThreadAdd, 1, HoPerThreadx2, WoPerThreadx2>,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, // CThreadTransferDstScalarPerVector,
1,
AddressSpace::Global, AddressSpace::Global,
AddressSpace::Vgpr, AddressSpace::Vgpr,
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
1, 1,
true>(d_k_n_hox2_wox2_global_desc, true>(d_k_n_hox2_wox2_global_desc,
make_multi_index(k_thread_data_on_global, make_multi_index(k_thread_data_on_global_add,
0, 0,
hox2_thread_data_on_global, hox2_thread_data_on_global,
wox2_thread_data_on_global)) wox2_thread_data_on_global))
...@@ -406,17 +415,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -406,17 +415,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
#endif #endif
#if 1 #if 0
for(index_t k_i = 0; k_i < KPerThread; ++k_i) for(index_t k_i = 0; k_i < KPerThreadAdd; ++k_i)
{ {
for(index_t h_i = 0; h_i < HoPerThreadx2; ++h_i) for(index_t h_i = 0; h_i < HoPerThreadx2; ++h_i)
{ {
for(index_t w_i = 0; w_i < WoPerThreadx2; ++w_i) for(index_t w_i = 0; w_i < WoPerThreadx2; ++w_i)
{ {
p_d_thread[d_k_n_hox2_wox2_thread_desc.CalculateOffset( p_d_thread[d_k_n_hox2_wox2_thread_desc.CalculateOffset(
make_tuple(k_i, 0, h_i, w_i))] += make_tuple(k_i, 0, h_i, w_i))] += 1;
p_c_thread[c_k_n_ho_wo_thread_desc.CalculateOffset( //p_c_thread[c_k_n_ho_wo_thread_desc.CalculateOffset(
make_tuple(k_i, 0, h_i / 2, w_i / 2))]; //make_tuple(k_i, 0, h_i / 2, w_i / 2))];
} }
} }
} }
...@@ -424,20 +433,21 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -424,20 +433,21 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
#if 1 #if 1
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatC, FloatAB,
FloatC, FloatAB,
decltype(d_k_n_hox2_wox2_thread_desc), decltype(d_k_n_hox2_wox2_thread_desc),
decltype(d_k_n_hox2_wox2_global_desc), decltype(d_k_n_hox2_wox2_global_desc),
Sequence<KPerThread, 1, HoPerThreadx2, WoPerThreadx2>, Sequence<KPerThreadAdd, 1, HoPerThreadx2, WoPerThreadx2>,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, // CThreadTransferDstScalarPerVector,
1,
AddressSpace::Vgpr, AddressSpace::Vgpr,
AddressSpace::Global, AddressSpace::Global,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
1, 1,
true>(d_k_n_hox2_wox2_global_desc, true>(d_k_n_hox2_wox2_global_desc,
make_multi_index(k_thread_data_on_global, make_multi_index(k_thread_data_on_global_add,
0, 0,
hox2_thread_data_on_global, hox2_thread_data_on_global,
wox2_thread_data_on_global)) wox2_thread_data_on_global))
...@@ -445,7 +455,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -445,7 +455,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
p_d_thread, p_d_thread,
d_k_n_hox2_wox2_global_desc, d_k_n_hox2_wox2_global_desc,
p_c_global, p_d_global,
c_k_n_ho_wo_global_tensor_iterator_hacks); c_k_n_ho_wo_global_tensor_iterator_hacks);
#endif #endif
} }
...@@ -458,7 +468,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -458,7 +468,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const BGlobalDesc& b_e_n_ho_wo_global_desc, const BGlobalDesc& b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const DGlobalDesc& d_k_n_hox2_wox2_global_desc, const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
const FloatC* __restrict__ p_d_global, FloatAB* __restrict__ p_d_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc, const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
...@@ -488,7 +498,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -488,7 +498,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const BGlobalDesc* p_b_e_n_ho_wo_global_desc, const BGlobalDesc* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const DGlobalDesc& d_k_n_hox2_wox2_global_desc, const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
const FloatC* __restrict__ p_d_global, FloatAB* __restrict__ p_d_global,
const CGlobalDesc* p_c_k_n_ho_wo_global_desc, const CGlobalDesc* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
...@@ -517,7 +527,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -517,7 +527,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const void* p_b_e_n_ho_wo_global_desc, const void* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const DGlobalDesc& d_k_n_hox2_wox2_global_desc, const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
const FloatC* __restrict__ p_d_global, FloatAB* __restrict__ p_d_global,
const void* p_c_k_n_ho_wo_global_desc, const void* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
......
...@@ -367,6 +367,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type ...@@ -367,6 +367,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) || (is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) || (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int32x4_t>::value && (N == 1)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)), (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
"wrong! not implemented"); "wrong! not implemented");
...@@ -467,6 +468,14 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type ...@@ -467,6 +468,14 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
0); 0);
} }
} }
else if constexpr(is_same<T, int32x4_t>::value)
{
__llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(is_same<T, half_t>::value) else if constexpr(is_same<T, half_t>::value)
{ {
if constexpr(N == 1) if constexpr(N == 1)
......
...@@ -92,7 +92,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -92,7 +92,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
const auto out_n_k0_ho_wo_k1_desc = const auto out_n_k0_ho_wo_k1_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1));
const auto add_n_k0_hox2_wox2_k1_desc = const auto add_n_k0_hox2_wox2_k1_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Hox2, Wox2, K1)); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Hox2, Wox2, 1));
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{}); const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{}); const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
...@@ -220,7 +220,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -220,7 +220,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
wei_k_c_y_x_device_buf.GetDeviceBuffer()), wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>( static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c_hi_wi_device_buf.GetDeviceBuffer()), in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(add_n_k_hox2_wox2_device_buf.GetDeviceBuffer()), static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
add_n_k_hox2_wox2_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_hox2_wox2_device_buf.GetDeviceBuffer())); static_cast<TOut*>(out_n_k_hox2_wox2_device_buf.GetDeviceBuffer()));
out_n_k_hox2_wox2_device_buf.FromDevice(out_n_k0_hox2_wox2_k1.mData.data()); out_n_k_hox2_wox2_device_buf.FromDevice(out_n_k0_hox2_wox2_k1.mData.data());
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment