Commit 6aaa77c1 authored by Jing Zhang

finished an example

parent 7e734a03
@@ -44,10 +44,10 @@ struct i32_to_i8
{
__host__ __device__ void operator()(I8& y, const I32& x) const
{
y = ck::type_convert<I8>(x) * scale;
y = ck::type_convert<I8>(ck::type_convert<float>(x) * reduced_amex_scale);
}
float scale = 1.0;
float reduced_amex_scale = 1.0;
};
using AElementOp = i32_to_i8;
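For intuition, here is a minimal host-side sketch of what the reworked functor computes: widen the i32 input to float, apply the scale (intended to hold the reciprocal of the input's absolute maximum, as set later in this commit), then narrow to i8. The `std::clamp`/`std::lround` combination is only an assumed stand-in for the rounding/saturation of the real `ck::type_convert<I8>`:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

// Host-side stand-in for the updated functor. std::clamp + std::lround
// approximate the saturating ck::type_convert<I8> (an assumption).
struct i32_to_i8_host
{
    void operator()(std::int8_t& y, const std::int32_t& x) const
    {
        const float scaled = static_cast<float>(x) * reduced_amex_scale;
        y = static_cast<std::int8_t>(std::clamp(std::lround(scaled), -128L, 127L));
    }
    float reduced_amex_scale = 1.0f; // intended to be 1 / amax
};

int main()
{
    i32_to_i8_host op{1.0f / 25.0f}; // e.g. amax == 25
    std::int8_t y = 0;
    op(y, 100);
    std::cout << static_cast<int>(y) << '\n'; // prints 4 (100 / 25)
}
```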
@@ -175,12 +175,15 @@ int main(int argc, char* argv[])
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
ADataType amax = 5;
BDataType bmax = 8;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-amax, amax});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-bmax, bmax});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
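Capturing the generator bounds in `amax` and `bmax` is what lets the verification path further down build the exactly matching scales `1.0 / amax` and `1.0 / bmax` for the reference GEMM.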
@@ -195,8 +198,8 @@ int main(int argc, char* argv[])
b_device_buf.ToDevice(b_k_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{0.2};
auto b_element_op = BElementOp{0.2};
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
@@ -254,8 +257,12 @@ int main(int argc, char* argv[])
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
auto ref_argument = ref_gemm.MakeArgument(a_m_k,
b_k_n,
c_m_n,
AElementOp{static_cast<float>(1.0) / amax},
BElementOp{static_cast<float>(1.0) / bmax},
PassThrough{});
ref_invoker.Run(ref_argument);
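Spelled out, the reference invocation above amounts to quantizing A and B element-wise with the known scales and then running a plain GEMM. A compact host sketch (hypothetical helper; row-major A and B assumed for brevity, with a plain cast standing in for `ck::type_convert`):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host reference: quantize A and B with the known scales
// (1/amax, 1/bmax), then accumulate the i8 products in i32.
std::vector<std::int32_t> ref_scaled_gemm(const std::vector<std::int32_t>& a, // M*K
                                          const std::vector<std::int32_t>& b, // K*N
                                          int M, int N, int K,
                                          float a_scale, float b_scale)
{
    auto quant = [](std::int32_t x, float s) {
        return static_cast<std::int8_t>(static_cast<float>(x) * s); // cf. i32_to_i8
    };
    std::vector<std::int32_t> c(static_cast<std::size_t>(M) * N, 0);
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                c[m * N + n] += quant(a[m * K + k], a_scale) * quant(b[k * N + n], b_scale);
    return c;
}
```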
@@ -265,8 +265,9 @@ int main(int argc, char* argv[])
if(!reduce_1.IsSupportedArgument(argument_ptr_1.get()))
{
std::cout << "The runtime parameters seems supported by the DeviceReduce instance, exiting!"
<< std::endl;
std::cout
<< "The runtime parameters seems not supported by the DeviceReduce instance, exiting!"
<< std::endl;
};
auto invoker_ptr_1 = reduce_1.MakeInvokerPointer();
@@ -73,8 +73,8 @@ using DeviceReduceInstance_2 = DeviceReduceMultiBlock<InOutDataType,
OutputIndex,
false, // HaveIndexInputIfOutputIndex
256,
128,
2,
32,
8,
1,
1,
1, // vector dim
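The cluster reshape keeps the launch geometry consistent: the thread cluster still multiplies out to the block size, MThreadClusterSize × KThreadClusterSize = 32 × 8 = 256 (just as 128 × 2 did), but eight threads now cooperate along the reduced dimension of each tile instead of two.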
@@ -83,14 +83,12 @@ using DeviceReduceInstance_2 = DeviceReduceMultiBlock<InOutDataType,
static bool do_verify;
static int init_method;
static float alpha;
static float beta;
static bool time_kernel;
int main(int argc, char* argv[])
{
// used by the device reduction
const std::array<int, 1> reduceDims_1 = {0};
const std::array<int, 1> reduceDims_1 = {1};
const std::array<int, 1> reduceDims_2 = {0};
// used by the host reduction
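The two reduce stages now compose an AMAX over the whole 2-D input: stage 1 collapses dimension 1, stage 2 collapses the remaining dimension 0. A host-side sketch of the same two-stage scheme, for intuition only (not the device path):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// AMAX of a 2-D tensor in two stages, mirroring reduceDims_1 = {1}
// followed by reduceDims_2 = {0}.
float amax_two_stage(const std::vector<float>& t, int rows, int cols)
{
    std::vector<float> stage1(rows, 0.0f); // stage 1: reduce dim 1 (columns)
    for(int r = 0; r < rows; ++r)
        for(int c = 0; c < cols; ++c)
            stage1[r] = std::max(stage1[r], std::fabs(t[r * cols + c]));

    float amax = 0.0f;                     // stage 2: reduce dim 0 (rows)
    for(int r = 0; r < rows; ++r)
        amax = std::max(amax, stage1[r]);
    return amax;
}
```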
@@ -126,9 +124,6 @@ int main(int argc, char* argv[])
throw std::runtime_error(ostr.str());
};
alpha = 1.0f;
beta = 0.0f;
Tensor<InOutDataType> in_1(inLengths_1);
Tensor<InOutDataType> out_ref(outLengths);
@@ -149,26 +144,12 @@ int main(int argc, char* argv[])
switch(init_method)
{
case 0: break;
case 1:
in_1.GenerateTensorValue(GeneratorTensor_1<InOutDataType>{1}, num_thread);
if(beta != 0.0f)
out_ref.GenerateTensorValue(GeneratorTensor_1<InOutDataType>{1}, num_thread);
break;
case 1: in_1.GenerateTensorValue(GeneratorTensor_1<InOutDataType>{1}, num_thread); break;
case 2:
in_1.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
if(beta != 0.0f)
out_ref.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
break;
default:
in_1.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-5.0, 5.0}, num_thread);
if(beta != 0.0f)
out_ref.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-5.0, 5.0},
num_thread);
default: in_1.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-5.0, 5.0}, num_thread);
}
if(beta != 0.0f)
for(size_t i = 0; i < out_ref.mDesc.GetElementSpaceSize(); i++)
out.mData[i] = out_ref.mData[i];
};
DeviceMem in_1_dev(sizeof(InOutDataType) * in_1.mDesc.GetElementSpaceSize());
@@ -177,9 +158,6 @@ int main(int argc, char* argv[])
in_1_dev.ToDevice(in_1.mData.data());
if(beta != 0.0f)
out_dev.ToDevice(out.mData.data());
InElementwiseOperation in_elementwise_op;
AccElementwiseOperation acc_elementwise_op;
@@ -222,8 +200,8 @@ int main(int argc, char* argv[])
arrOutLengths,
arrOutStrides,
reduceDims,
static_cast<double>(alpha),
static_cast<double>(beta),
1.0,
0.0,
in_1.mData.data(),
nullptr,
out_ref.mData.data(),
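With `alpha` and `beta` hard-coded to 1.0 and 0.0, the usual reduction update out = alpha * reduce(in) + beta * out degenerates to a plain overwrite out = reduce(in), which is why the `beta != 0.0f` output-initialization branches and the static `alpha`/`beta` globals above could be deleted.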
@@ -261,8 +239,9 @@ int main(int argc, char* argv[])
if(!reduce_1.IsSupportedArgument(argument_ptr_1.get()))
{
std::cout << "The runtime parameters seems supported by the DeviceReduce instance, exiting!"
<< std::endl;
std::cout
<< "The runtime parameters seems not supported by the DeviceReduce instance, exiting!"
<< std::endl;
};
auto invoker_ptr_1 = reduce_1.MakeInvokerPointer();
@@ -274,8 +253,8 @@ int main(int argc, char* argv[])
arrOutLengths,
arrOutStrides,
reduceDims_2,
static_cast<double>(alpha),
static_cast<double>(beta),
1.0,
0.0,
in_2_dev.GetDeviceBuffer(),
nullptr,
out_dev.GetDeviceBuffer(),
@@ -20,6 +20,8 @@
#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/utility/is_detected.hpp"
namespace ck {
template <typename GridwiseGemm,
@@ -164,6 +166,8 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
{
using DeviceOp = DeviceGemmMultipleDScaleAB_Xdl_CShuffle;
using RowMajor = tensor_layout::gemm::RowMajor;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
@@ -177,7 +181,7 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
if constexpr(is_same_v<RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
@@ -195,7 +199,7 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
if constexpr(is_same<RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
@@ -214,7 +218,7 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{
const auto e_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
if constexpr(is_same<RowMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE, I1));
@@ -425,7 +429,7 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
index_t KRaw_;
};
template <typename InOutDataType>
template <typename InOutDataType, typename Layout>
struct Reduce2D
{
static constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::AMAX;
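`Reduce2D` is pinned to `ReduceTensorOp::AMAX`, the absolute-maximum reduction amax(x) = max_i |x_i|; its reciprocal 1 / amax is exactly what `Run()` below writes into `reduced_amex_scale`.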
@@ -440,26 +444,27 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
using PassThroughOp = tensor_operation::element_wise::PassThrough;
using DeviceReduceInstance_1 = DeviceReduceMultiBlock<InOutDataType,
InOutDataType,
InOutDataType,
2, // Rank
1, // NumReduceDim
ReduceOperation,
InElementwiseOperation,
PassThroughOp,
InMemoryDataOperationEnum::Set,
PropagateNan,
OutputIndex,
false, // HaveIndexInputIfOutputIndex
256,
32,
8,
1,
1,
1, // vector dim
1,
1>;
using DeviceReduceInstance_1 =
DeviceReduceMultiBlock<InOutDataType,
InOutDataType,
InOutDataType,
2, // Rank
1, // NumReduceDim
ReduceOperation,
InElementwiseOperation,
PassThroughOp,
InMemoryDataOperationEnum::Set,
PropagateNan,
OutputIndex,
false, // HaveIndexInputIfOutputIndex
256,
32,
8,
1,
1,
is_same<RowMajor, Layout>::value ? 1 : 0, // vector dim
1,
1>;
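The layout-dependent vector dim follows from stride: vectorized accesses need unit stride, and in a row-major `[rows, cols]` tensor element `(r, c)` lives at offset `r * cols + c`, so dimension 1 is the contiguous one; a column-major layout flips that to dimension 0.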
using DeviceReduceInstance_2 = DeviceReduceMultiBlock<InOutDataType,
InOutDataType,
@@ -473,9 +478,9 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
PropagateNan,
OutputIndex,
false, // HaveIndexInputIfOutputIndex
256,
128,
2,
256, // BlockSize
32, // MThreadClusterSize
8, // KThreadClusterSize
1,
1,
1, // vector dim
@@ -493,7 +498,7 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
const std::array<int, 1> reduceDims_1 = {arrInLengths_1[0] > arrInLengths_1[1] ? 0 : 1};
const std::array<int, 1> reduceDims_2 = {0};
std::array<index_t, 1> arrInLengths_2{arrInLengths_1[reduceDims_1[0]]};
std::array<index_t, 1> arrInLengths_2{arrInLengths_1[!reduceDims_1[0]]};
std::array<index_t, 1> arrInStrides_2{1};
std::array<index_t, 1> arrOutLengths{1};
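The `arrInLengths_2` fix is worth spelling out: `reduceDims_1[0]` names the dimension reduced in stage 1, so the stage-2 input must take the extent of the other, surviving dimension, `arrInLengths_1[!reduceDims_1[0]]`. The old code indexed with the reduced dimension itself, handing stage 2 the wrong length whenever the two extents differed.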
@@ -520,9 +525,10 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
if(!reduce_1.IsSupportedArgument(argument_ptr_1.get()))
{
std::cout << "The runtime parameters seems supported by the DeviceReduce instance, "
"exiting!"
<< std::endl;
std::cout
<< "The runtime parameters seems not supported by the DeviceReduce instance, "
"exiting!"
<< std::endl;
};
auto invoker_ptr_1 = reduce_1.MakeInvokerPointer();
@@ -564,6 +570,9 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
{
using Argument = DeviceOp::Argument;
template <typename T>
using has_reduced_amex_scale = decltype(std::declval<T&>().reduced_amex_scale);
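This alias enables the detection idiom: `has_reduced_amex_scale<T>` is well-formed only for element ops that actually expose a `reduced_amex_scale` member, and `is_detected` (from the `ck/utility/is_detected.hpp` include added above) turns that into a compile-time boolean. A self-contained sketch of the same pattern, with a minimal `is_detected` that is assumed to follow the library-fundamentals-TS formulation:

```cpp
#include <iostream>
#include <type_traits>
#include <utility>

// Minimal stand-in for ck/utility/is_detected.hpp (assumed to take the
// detection metafunction first, as in the usage below).
template <template <typename> class, typename, typename = void>
struct is_detected : std::false_type {};
template <template <typename> class Op, typename T>
struct is_detected<Op, T, std::void_t<Op<T>>> : std::true_type {};

// Same alias as in the commit: well-formed only if T has the member.
template <typename T>
using has_reduced_amex_scale = decltype(std::declval<T&>().reduced_amex_scale);

struct WithScale    { float reduced_amex_scale = 1.0f; };
struct WithoutScale {};

int main()
{
    std::cout << is_detected<has_reduced_amex_scale, WithScale>::value    // 1
              << is_detected<has_reduced_amex_scale, WithoutScale>::value // 0
              << '\n';
}
```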
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
@@ -575,45 +584,67 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
using RowMajor = tensor_layout::gemm::RowMajor;
float kern_time = 0;
ADataType amax_a, amax_b;
auto reduce_a = Reduce2D<ADataType>{};
kern_time += reduce_a.Run({arg.MRaw_, arg.KRaw_},
is_same<RowMajor, ALayout>::value // A[M, K]
? std::array<index_t, 2>{arg.KRaw_, I1}
: std::array<index_t, 2>{I1, arg.MRaw_},
arg.p_a_grid_,
arg.p_e_grid_,
arg.p_e_grid_,
stream_config);
hipGetErrorString(hipMemcpyWithStream(&amax_a,
arg.p_e_grid_,
sizeof(ADataType),
hipMemcpyDeviceToHost,
stream_config.stream_id_));
auto reduce_b = Reduce2D<BDataType>{};
kern_time += reduce_b.Run({arg.KRaw_, arg.NRaw_},
is_same<RowMajor, BLayout>::value // B[K, N]
? std::array<index_t, 2>{arg.NRaw_, I1}
: std::array<index_t, 2>{I1, arg.KRaw_},
arg.p_a_grid_,
arg.p_e_grid_,
arg.p_e_grid_,
stream_config);
hipGetErrorString(hipMemcpyWithStream(&amax_b,
arg.p_e_grid_,
sizeof(ADataType),
hipMemcpyDeviceToHost,
stream_config.stream_id_));
// std::cout << "amax_a: " << amax_a << " amax_b: " << amax_b << std::endl;
AElementwiseOperation a_element_op_ = arg.a_element_op_;
if constexpr(is_detected<has_reduced_amex_scale, AElementwiseOperation>::value)
{
ADataType amax_a;
auto reduce_a = Reduce2D<ADataType, ALayout>{};
kern_time += reduce_a.Run({arg.MRaw_, arg.KRaw_},
is_same<RowMajor, ALayout>::value // A[M, K]
? std::array<index_t, 2>{arg.KRaw_, I1}
: std::array<index_t, 2>{I1, arg.MRaw_},
arg.p_a_grid_,
arg.p_e_grid_,
arg.p_e_grid_,
stream_config);
hipGetErrorString(hipMemcpyWithStream(&amax_a,
arg.p_e_grid_,
sizeof(ADataType),
hipMemcpyDeviceToHost,
stream_config.stream_id_));
static_assert(is_same<decltype(arg.a_element_op_.reduced_amex_scale), float>::value,
"scale is not float!");
a_element_op_.reduced_amex_scale = 1.0 / amax_a;
// std::cout << " amax_a: " << amax_a << std::endl;
}
BElementwiseOperation b_element_op_ = arg.b_element_op_;
if constexpr(is_detected<has_reduced_amex_scale, BElementwiseOperation>::value)
{
BDataType amax_b;
auto reduce_b = Reduce2D<BDataType, BLayout>{};
kern_time += reduce_b.Run({arg.KRaw_, arg.NRaw_},
is_same<RowMajor, BLayout>::value // B[K, N]
? std::array<index_t, 2>{arg.NRaw_, I1}
: std::array<index_t, 2>{I1, arg.KRaw_},
arg.p_b_grid_,
arg.p_e_grid_,
arg.p_e_grid_,
stream_config);
hipGetErrorString(hipMemcpyWithStream(&amax_b,
arg.p_e_grid_,
sizeof(BDataType),
hipMemcpyDeviceToHost,
stream_config.stream_id_));
static_assert(is_same<decltype(arg.b_element_op_.reduced_amex_scale), float>::value,
"scale is not float!");
b_element_op_.reduced_amex_scale = 1.0 / amax_b;
// std::cout << " amax_b: " << amax_b << std::endl;
}
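Both branches follow the same pattern: run the two-stage AMAX reduction into the front of the E buffer (reused as scratch before the GEMM overwrites it), copy the single scalar back to the host, and store its reciprocal in a local copy of the element op that is later handed to the kernel in place of `arg.a_element_op_`/`arg.b_element_op_`. The read-back relies on `hipMemcpyWithStream` being synchronous with respect to the host; a minimal sketch of that step:

```cpp
#include <hip/hip_runtime.h>

// Sketch of the scalar read-back used above: copy one reduced value from
// device memory on the given stream. hipMemcpyWithStream does not return
// until the copy has completed, so host_value is safe to use afterwards.
template <typename T>
T read_back_scalar(const void* device_ptr, hipStream_t stream)
{
    T host_value{};
    (void)hipGetErrorString(hipMemcpyWithStream(
        &host_value, device_ptr, sizeof(T), hipMemcpyDeviceToHost, stream));
    return host_value;
}
```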
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
@@ -646,8 +677,8 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
a_element_op_,
b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,