Commit 81b26528 authored by Chao Liu

added bias add; worked around compiler issues

parent 4f2c8bce
......@@ -50,7 +50,7 @@ template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcElementwiseOperation,
typename DstElementwiseOperation,
typename SliceLengths,
typename DimAccessOrder,
index_t DstVectorDim,
......@@ -72,9 +72,9 @@ struct ThreadwiseTensorSliceTransfer_v1r3
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(
const DstDesc& dst_desc,
const Index& dst_slice_origin_idx,
const SrcElementwiseOperation src_element_op)
const DstElementwiseOperation& dst_element_op)
: dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
src_element_op_{src_element_op}
dst_element_op_{dst_element_op}
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
......@@ -201,7 +201,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
// apply element-wise operation and type convert
dst_vector.template AsType<DstData>()(i) =
type_convert<DstData>(src_element_op_(src_buf[Number<src_offset>{}]));
type_convert<DstData>(dst_element_op_(src_buf[Number<src_offset>{}]));
});
const bool is_dst_valid =
......@@ -378,7 +378,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
private:
DstCoord dst_coord_;
SrcElementwiseOperation src_element_op_;
const DstElementwiseOperation dst_element_op_;
}; // struct ThreadwiseTensorSliceTransfer_v1r3
// Assume:
......
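For orientation only (not part of the commit): in v1r3 the elementwise operation is now named and applied on the destination side, once per scalar, before the conversion into DstData. A minimal standalone sketch of that store pattern, with illustrative names rather than the library's API:

#include <cstddef>

// Hedged sketch: apply the destination-side elementwise op to each source
// scalar, then convert to the destination data type before storing
// (mirrors the per-element loop in ThreadwiseTensorSliceTransfer_v1r3).
template <typename SrcData, typename DstData, typename DstElementwiseOperation>
void store_slice(const SrcData* src, DstData* dst, std::size_t n,
                 const DstElementwiseOperation& dst_element_op)
{
    for(std::size_t i = 0; i < n; ++i)
    {
        dst[i] = static_cast<DstData>(dst_element_op(src[i]));
    }
}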
......@@ -32,7 +32,7 @@ template <typename SrcData,
typename DstDesc,
typename Dst0Desc, // this is really one of the sources, but it has the same shape as DstDesc
typename Dst1Desc, // this is really one of the sources, but it has the same shape as DstDesc
typename SrcElementwiseOperation,
typename DstElementwiseOperation,
typename SliceLengths,
typename DimAccessOrder,
index_t DstVectorDim,
......@@ -60,11 +60,11 @@ struct ThreadwiseTensorSliceTransfer_v1r4
const Dst0Desc& dst0_desc,
const Dst1Desc& dst1_desc,
const Index& dst_slice_origin_idx,
const SrcElementwiseOperation src_element_op)
const DstElementwiseOperation& dst_element_op)
: dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
dst0_coord_(make_tensor_coordinate(dst0_desc, dst_slice_origin_idx)),
dst1_coord_(make_tensor_coordinate(dst1_desc, dst_slice_origin_idx)),
src_element_op_{src_element_op}
dst_element_op_{dst_element_op}
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
......@@ -258,15 +258,45 @@ struct ThreadwiseTensorSliceTransfer_v1r4
using dst_vector_t =
typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
// copy data from src_buf into dst_vector
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector);
// apply element-wise operation and type convert
dst_vector.template AsType<DstData>()(i) =
type_convert<DstData>(src_element_op_(src_buf[Number<src_offset>{}]));
});
// load dst0 and dst1 and apply elementwise operation
{
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
// TODO: fix this
static_assert(DstScalarPerVector == 1, "wrong!");
// copy data from src_buf into dst_vector_src_data
constexpr index_t src_offset =
src_desc.CalculateOffset(src_slice_origin_idx + dst_data_idx);
const SrcData src_v = src_buf[Number<src_offset>{}];
// load dst0 and dst1
const bool is_dst0_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst0_desc,
dst0_coord_);
const bool is_dst1_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst1_desc,
dst1_coord_);
const DstData dst0_v =
dst0_buf.template Get<DstData>(dst0_coord_.GetOffset(), is_dst0_valid);
const DstData dst1_v =
dst1_buf.template Get<DstData>(dst1_coord_.GetOffset(), is_dst1_valid);
#if !CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE
// apply element-wise operation in SrcData type
const SrcData dst_v = dst_element_op_(
src_v, type_convert<SrcData>(dst0_v), type_convert<SrcData>(dst1_v));
// apply type convert
dst_vector.template AsType<DstData>()(Number<0>{}) = type_convert<DstData>(dst_v);
#else
// apply element-wise operation in DstData type
const DstData dst_v = dst_element_op_(src_v, dst0_v, dst1_v);
dst_vector.template AsType<DstData>()(Number<0>{}) = dst_v;
#endif
}
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
......@@ -327,11 +357,27 @@ struct ThreadwiseTensorSliceTransfer_v1r4
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]);
// dst0
move_tensor_coordinate(
dst0_desc, dst0_coord_, dst0_forward_steps[dim_access_order[i]]);
// dst1
move_tensor_coordinate(
dst1_desc, dst1_coord_, dst1_forward_steps[dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]);
// dst0
move_tensor_coordinate(
dst0_desc, dst0_coord_, dst0_backward_steps[dim_access_order[i]]);
// dst1
move_tensor_coordinate(
dst1_desc, dst1_coord_, dst1_backward_steps[dim_access_order[i]]);
}
}
});
......@@ -469,7 +515,7 @@ struct ThreadwiseTensorSliceTransfer_v1r4
DstCoord dst_coord_;
Dst0Coord dst0_coord_;
Dst1Coord dst1_coord_;
SrcElementwiseOperation src_element_op_;
const DstElementwiseOperation dst_element_op_;
}; // struct ThreadwiseTensorSliceTransfer_v1r4
} // namespace ck
......
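Again as illustration only: v1r4 additionally reads matching elements from two destination-shaped operands (dst0, the residual, and dst1, the broadcast bias) and feeds all three values to the elementwise operation, with the CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE macro choosing whether the math happens in SrcData or DstData. A rough scalar sketch under those assumptions (UseDstMath stands in for the macro; names are illustrative, not the library's API):

// Hedged sketch of the v1r4 per-element epilogue (DstScalarPerVector == 1).
// SrcData is the accumulator type (fp32 in the example), DstData the output
// type (fp16 in the example).
template <bool UseDstMath, typename SrcData, typename DstData, typename Op>
DstData store_one(SrcData src_v, DstData dst0_v, DstData dst1_v, const Op& op)
{
    if constexpr(!UseDstMath)
    {
        // preferred path: compute in SrcData (fp32), convert once at the end
        return static_cast<DstData>(
            op(src_v, static_cast<SrcData>(dst0_v), static_cast<SrcData>(dst1_v)));
    }
    else
    {
        // workaround path: apply the op on the DstData operands directly,
        // avoiding the fp32<->fp16 conversions that trigger register spills
        return static_cast<DstData>(op(src_v, dst0_v, dst1_v));
    }
}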
......@@ -810,7 +810,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
SrcCoord src_coord_;
DstCoord dst_coord_;
SrcElementwiseOperation src_element_op_;
const SrcElementwiseOperation src_element_op_;
};
} // namespace ck
......
......@@ -136,6 +136,11 @@
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif
// workaround for register spills caused by a compiler issue when converting between fp32 and fp16
#ifndef CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE
#define CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE 1
#endif
namespace ck {
enum InMemoryDataOperationEnum_t
......
......@@ -14,10 +14,6 @@
#include "device_base.hpp"
#include "example/2_gemm_xdl_bias_add/include/device_gemm_xdl_bias_add.hpp"
// C[m, n] = alpha(A[m, k] * B[k, n]) + beta * C0[m, n] + gamma * C1[m]
// assume C0 has same layout as C
// assume C1 is contiguous in memory
struct PassThrough
{
template <typename T>
......@@ -27,17 +23,60 @@ struct PassThrough
}
};
struct Relu
// GEMM Bias Add:
// C[m, n] = alpha(A[m, k] * B[k, n]) + beta * C0[m, n] + gamma * C1[m]
// assume C0 has same layout as C
// assume C1 is contiguous in memory
// C1 is stored in memory as a 1D vector, but is treated as a 2D matrix C1[m, n] with stride = 0 in
// the "n" dimension
//
// alpha * v0 + beta * v1 + gamma * v2
// v0 is from the C matrix
// v1 is from the residual matrix C0
// v2 is from the bias vector C1
struct BiasAdd
{
float alpha = 0.1;
#if 1
// correct result
// no scratch memory, good VGPR allocation (59)
// good perf (101Tflops)
template <typename T1, typename T2>
__host__ __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
{
// the compiler seems very sensitive to the order of these calculations:
// it is eager to read AccVgpr (v0) out prematurely, resulting in register
// over-allocation. Therefore, the v0 calculation is moved to the very end
float a = T1(0.2) * v1 + T2(0.3) * v2;
float b = a + float(0.1) * v0;
// ReLU
template <typename T>
__host__ __device__ constexpr T operator()(T v) const
return b;
}
#elif 0
// correct result
// some scratch memory (68), large VGPR usage (126)
// very little perf drop (101Tflops)
__host__ __device__ constexpr auto operator()(float v0, ck::half_t v1, ck::half_t v2) const
{
T tmp = alpha * v;
return tmp > 0 ? tmp : 0;
return float(0.1) * v0 + ck::half_t(0.2) * v1 + ck::half_t(0.3) * v2;
}
#elif 0
// correct result
// some scratch memory (68 dwords)
// some perf drop (94Tflops)
// fp64 instructions are used
__host__ __device__ constexpr auto operator()(float v0, ck::half_t v1, ck::half_t v2) const
{
return 0.1 * v0 + 0.2 * v1 + 0.3 * v2;
}
#elif 1
// wrong result
// lots of scratch memory
// huge perf drop
__host__ __device__ constexpr auto operator()(float v0, ck::half_t v1, ck::half_t v2) const
{
return float(0.1) * v0 + float(0.2) * v1 + float(0.3) * v2;
}
#endif
};
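For a quick sanity check of the functor's arithmetic (illustrative only, assuming the BiasAdd definition above with its #if 1 branch is in scope; the 0.1/0.2/0.3 factors are the hard-coded scales from that branch):

#include <cstdio>

int main()
{
    BiasAdd op{};
    // v0 = GEMM accumulator, v1 = residual element, v2 = bias element
    // 0.2 * 1.0 + 0.3 * 4.0 = 1.4, then + 0.1 * 2.0 = 1.6
    float out = op(2.0f, 1.0f, 4.0f);
    std::printf("%f\n", out); // prints approximately 1.600000
    return 0;
}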
template <typename ADataType,
......@@ -125,13 +164,49 @@ struct DeviceGemmInstance<float,
// clang-format on
};
template <typename AType,
typename BType,
typename CType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
static void host_verify(const Tensor<AType>& a_m_k,
const Tensor<BType>& b_k_n,
Tensor<CType>& c_m_n,
const Tensor<CType>& c0_m_n,
const Tensor<CType>& c1_m_n,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
{
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = a_m_k.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += static_cast<double>(a_element_op(a_m_k(m, k))) *
static_cast<double>(b_element_op(b_k_n(k, n)));
}
c_m_n(m, n) = c_element_op(
v, static_cast<double>(c0_m_n(m, n)), static_cast<double>(c1_m_n(m, n)));
};
make_ParallelTensorFunctor(f_mk_kn_mn,
c_m_n.mDesc.GetLengths()[0],
c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
}
int main(int argc, char* argv[])
{
if(argc != 4)
if(argc != 10)
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}
......@@ -140,18 +215,24 @@ int main(int argc, char* argv[])
const int nrepeat = std::stoi(argv[3]);
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t M = std::stoi(argv[4]);
ck::index_t N = std::stoi(argv[5]);
ck::index_t K = std::stoi(argv[6]);
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
ck::index_t StrideA = std::stoi(argv[7]);
ck::index_t StrideB = std::stoi(argv[8]);
ck::index_t StrideC = std::stoi(argv[9]);
// matrix data type
#if 1
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
#else
using ADataType = float;
using BDataType = float;
using CDataType = float;
#endif
// matrix layout
using ALayout = ck::tensor_layout::gemm::RowMajor;
......@@ -219,6 +300,8 @@ int main(int argc, char* argv[])
c0_m_n_device_buf.ToDevice(c0_m_n.mData.data());
c1_m_n_device_buf.ToDevice(c1_m_n.mData.data());
auto c_element_op = BiasAdd{};
// do GEMM
auto gemm = typename DeviceGemmInstance<ADataType,
BDataType,
......@@ -228,7 +311,7 @@ int main(int argc, char* argv[])
CLayout,
PassThrough,
PassThrough,
Relu>::type{};
decltype(c_element_op)>::type{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
......@@ -244,7 +327,7 @@ int main(int argc, char* argv[])
StrideC,
PassThrough{},
PassThrough{},
Relu{});
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
......@@ -270,8 +353,13 @@ int main(int argc, char* argv[])
if(do_verification)
{
host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, Relu{});
check_error(c_m_n_host_result, c_m_n_device_result);
host_verify(a_m_k,
b_k_n,
c_m_n_host_result,
c0_m_n,
c1_m_n,
PassThrough{},
PassThrough{},
c_element_op);
}
}