"example/vscode:/vscode.git/clone" did not exist on "12235112a10ecbe47acead9a03564cb42c4624c2"
Commit 5e98fc5b authored by Jing Zhang's avatar Jing Zhang
Browse files

Fixed fp8 init and reference gemm

parent 37a8c1f7
......@@ -247,7 +247,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
ADataType,
BDataType,
ComputeDataType,
AccDataType,
......
......@@ -27,6 +27,12 @@ struct PassThrough
y = x;
}
// PassThrough specialization for a double input written to a float output.
template <>
__host__ __device__ void operator()<float, double>(float& y, const double& x) const
{
    // Same result as the implicit conversion `y = x`, with the narrowing spelled out.
    y = static_cast<float>(x);
}
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
......@@ -412,14 +418,14 @@ struct Swish
{
Swish(float beta = 1.0f) : beta_(beta) {}
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value,
static_assert(is_same<X, float>::value || is_same<X, double>::value ||
is_same<X, ck::half_t>::value,
"Data type is not supported by this operation!");
y = x / (ck::type_convert<T>(1) + ck::math::exp(-beta_ * x));
y = x / (ck::type_convert<Y>(1) + ck::math::exp(-beta_ * x));
};
float beta_ = 1.0f;
......
......@@ -137,13 +137,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
SrcData v;
DstData v;
// apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]);
// apply type convert
dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(v);
dst_vector.template AsType<DstData>()(i) = v;
});
const bool is_dst_valid =
......@@ -1289,13 +1288,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
SrcData v;
DstData v;
// apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]);
// apply type convert
dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v);
dst_buf(Number<dst_offset>{}) = v;
});
});
}
......
......@@ -20,7 +20,8 @@ template <typename ADataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
typename CElementwiseOperation,
typename ComputeType = ADataType>
struct ReferenceGemm : public device::BaseOperator
{
// Argument
......@@ -64,8 +65,8 @@ struct ReferenceGemm : public device::BaseOperator
for(int k = 0; k < K; ++k)
{
ADataType v_a;
BDataType v_b;
ComputeType v_a;
ComputeType v_b;
// use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation,
......@@ -88,8 +89,7 @@ struct ReferenceGemm : public device::BaseOperator
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
}
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
v_acc += type_convert<AccDataType>(v_a * v_b);
}
CDataType v_c;
......
......@@ -55,6 +55,18 @@ struct GeneratorTensor_1<int8_t>
}
};
// GeneratorTensor_1 specialization for ck::f8_t: every call returns `value`
// converted to ck::f8_t, regardless of the arguments passed in.
template <>
struct GeneratorTensor_1<ck::f8_t>
{
    // Constant to emit, stored as float and converted on each call.
    float value = 1.0f;

    // Accepts any number of arguments of any type and ignores all of them.
    template <typename... Indices>
    ck::f8_t operator()(Indices...)
    {
        const float constant = value;
        return ck::type_convert<ck::f8_t>(constant);
    }
};
template <typename T>
struct GeneratorTensor_2
{
......
......@@ -83,8 +83,8 @@ bool profile_gemm_multiply_add_impl(int do_verification,
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-1, 1});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 0.2});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.1, 0.1});
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment