Commit b00ae5df authored by Chao Liu

update im2col

parent 95268357
#include "tile_program.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
#include "ck/library/utility/device_memory.hpp"
// program
struct GemmMultiplD
{
__host__ __device__ void operator()(TileProgram& tp, int x, int y)
{
auto desc = tp.make_naive_tensor_descriptor_packed(ck::make_tuple(x));
printf("length %d\n", desc.GetLength(ck::Number<0>{}));
}
};
int main()
{
int x = 100;
int y = 101;
launch(GemmMultiplD{}, 1, 1, x, y);
return 0;
}
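// A sketch of the expected behavior, assuming launch() forwards the trailing
// arguments (x, y) to the functor as kernel arguments: each launched thread
// builds a packed 1-D descriptor of length x, so the program above should
// print "length 100"; the y argument is currently unused by the functor.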
@@ -229,50 +229,63 @@ struct Im2Col
constexpr auto I0 = Number<0>{};
#if 1 // debug
const auto a_src_desc = tensor_operation::TransformConvFwdToGemm<
NDimSpatial,
tensor_operation::device::ConvolutionForwardSpecialization::Default>::
template MakeADescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
c_g_n_k_wos_lengths,
c_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
#else
const auto a_src_desc =
ps(tensor_operation::TransformConvFwdToGemm<
NDimSpatial,
tensor_operation::device::ConvolutionForwardSpecialization::Default>::
template MakeADescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
c_g_n_k_wos_lengths,
c_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads));
#endif
#if 1 // debug
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Hi = a_g_n_c_wis_lengths[3];
const index_t Wi = a_g_n_c_wis_lengths[4];
const index_t Ho = c_g_n_k_wos_lengths[3];
const index_t Wo = c_g_n_k_wos_lengths[4];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
const index_t Y = b_g_k_c_xs_lengths[3];
const index_t X = b_g_k_c_xs_lengths[4];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const auto a_n_hi_wi_c = make_naive_tensor_packed<AddressSpaceEnum::Global, true>(
make_tuple(N, Hi, Wi, C), p_a_img);
const auto a_n_hip_wip_c = transform_tensor(
a_n_hi_wi_c,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto a_n_y_ho_x_wo_c = transform_tensor(
a_n_hip_wip_c,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto a_src =
transform_tensor(a_n_y_ho_x_wo_c,
make_tuple(ps(make_merge_transform(make_tuple(N, Ho, Wo))),
ps(make_merge_transform(make_tuple(Y, X, C)))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
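// A reading of the transform chain above (explanatory comment, not part of the
// original change): the pad transforms extend Hi/Wi by the left/right pads, the
// embed transforms map (y, ho) -> hi = y * ConvDilationH + ho * ConvStrideH
// (and (x, wo) analogously), and the final merge folds (N, Ho, Wo) into GEMM M
// and (Y, X, C) into GEMM K, so a_src presents the input image as an implicit
// M x K matrix for im2col without materializing it.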
auto a_dst = make_naive_tensor<AddressSpaceEnum::Global, true>(
make_tuple(a_gemmm_gemmk_lengths[0], a_gemmm_gemmk_lengths[1]),
make_tuple(a_gemmm_gemmk_strides[0], a_gemmm_gemmk_strides[1]),
p_a_mtx);
#else
const auto a_dst_desc = ps(make_naive_tensor_descriptor(
make_tuple(a_gemmm_gemmk_lengths[0], a_gemmm_gemmk_lengths[1]),
make_tuple(a_gemmm_gemmk_strides[0], a_gemmm_gemmk_strides[1])));
const auto a_src = make_tensor<AddressSpaceEnum::Global, true>(a_src_desc, p_a_img);
auto a_dst = make_tensor<AddressSpaceEnum::Global, true>(a_dst_desc, p_a_mtx);
#endif
const auto num_gemmm = a_gemmm_gemmk_lengths[0];
const auto num_gemmk = a_gemmm_gemmk_lengths[1];
@@ -281,7 +294,7 @@ struct Im2Col
const auto num_tile_m = ps.read_first_lane(num_gemmm / kMPerTile);
#if 0 // debug
const auto block2tile = make_cluster_descriptor(make_tuple(num_tile_m));
#else
const auto block2tile = ps(make_cluster_descriptor(make_tuple(num_tile_m)));
......
@@ -14,11 +14,11 @@ struct Tensor
using TensorDescriptor = remove_cvref_t<TensorDescTmp>;
using DataType = remove_reference_t<T>;
static constexpr AddressSpaceEnum kAddressSpace_ = AddressSpace;
static constexpr bool kInvalidElementUseNumericalZeroValue_ =
InvalidElementUseNumericalZeroValue;
__host__ __device__ constexpr Tensor() : buf_{}, desc_{} {}
__host__ __device__ constexpr Tensor(DataType* p_data, TensorDescriptor desc)
: buf_{p_data, desc.GetElementSpaceSize()}, desc_{desc}
@@ -52,6 +52,36 @@ __host__ __device__ constexpr auto make_tensor(const TensorDesc& desc, T* p_data
p_data, desc};
}
template <AddressSpaceEnum AddressSpace,
bool InvalidElementUseNumericalZeroValue,
typename T,
typename... Lengths,
typename... Strides,
typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
__host__ __device__ constexpr auto make_naive_tensor(const Tuple<Lengths...>& lengths,
const Tuple<Strides...>& strides,
T* p_data,
T invalid_element_value = 0)
{
auto desc = make_naive_tensor_descriptor(lengths, strides);
return Tensor<AddressSpace, InvalidElementUseNumericalZeroValue, T, decltype(desc)>{
p_data, desc, invalid_element_value};
}
template <AddressSpaceEnum AddressSpace,
bool InvalidElementUseNumericalZeroValue,
typename T,
typename... Lengths>
__host__ __device__ constexpr auto
make_naive_tensor_packed(const Tuple<Lengths...>& lengths, T* p_data, T invalid_element_value = 0)
{
auto desc = make_naive_tensor_descriptor_packed(lengths);
return Tensor<AddressSpace, InvalidElementUseNumericalZeroValue, T, decltype(desc)>{
p_data, desc, invalid_element_value};
}
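// Illustrative usage of the two helpers above (a sketch; p_data, M, K, N, Hi,
// Wi and C are placeholders, not part of this change):
//
//   // strided 2-D view over an existing buffer: lengths (M, K), strides (K, 1)
//   auto t_strided = make_naive_tensor<AddressSpaceEnum::Global, true>(
//       make_tuple(M, K), make_tuple(K, 1), p_data);
//
//   // packed 4-D view: the strides are derived from the lengths
//   auto t_packed = make_naive_tensor_packed<AddressSpaceEnum::Global, true>(
//       make_tuple(N, Hi, Wi, C), p_data);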
template <typename OldTensor,
typename NewTransforms,
typename NewLowerDimensionOldVisibleIdss,
@@ -61,13 +91,13 @@ __host__ __device__ constexpr auto transform_tensor(const OldTensor& old_tensor,
NewLowerDimensionOldVisibleIdss,
NewUpperDimensionNewVisibleIdss)
{
const auto new_desc = transform_tensor_descriptor(old_tensor.desc_,
new_transforms,
NewLowerDimensionOldVisibleIdss{},
NewUpperDimensionNewVisibleIdss{});
return Tensor<OldTensor::kAddressSpace_,
OldTensor::kInvalidElementUseNumericalZeroValue_,
typename OldTensor::DataType,
remove_cvref_t<decltype(new_desc)>>{old_tensor.buf_.p_data_, new_desc};
}
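// Illustrative use of transform_tensor (a sketch modeled on the im2col change
// in this commit; t, Hi, LeftPad and RightPad are placeholders): the returned
// Tensor shares the original buffer but indexes through the transformed
// descriptor, e.g. padding a 1-D tensor of length Hi on both sides:
//
//   auto t_pad = transform_tensor(t,
//                                 make_tuple(make_pad_transform(Hi, LeftPad, RightPad)),
//                                 make_tuple(Sequence<0>{}),
//                                 make_tuple(Sequence<0>{}));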
......