Commit fd1cf141 authored by Jing Zhang
Browse files

add fp32 out client example

parent abef7c4e
......@@ -26,7 +26,7 @@ using ADataType = F16;
using BDataType = F16;
using D0DataType = F32;
using DsDataType = ck::Tuple<D0DataType>;
using EDataType = F16;
using EDataType = F32;
using ALayout = Row;
using BLayout = Col;
......
......@@ -16,7 +16,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
//fp16_output
// fp16_output
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
......@@ -43,7 +43,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
PassThrough,
AddBias>>>& instances);
//fp32_output
// fp32_output
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
......@@ -70,10 +70,6 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(
PassThrough,
AddBias>>>& instances);
template <typename ALayout,
typename BLayout,
typename ELayout,
......@@ -109,7 +105,7 @@ struct DeviceOperationInstanceFactory<
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
//fp16_output
// fp16_output
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
......@@ -125,7 +121,7 @@ struct DeviceOperationInstanceFactory<
}
}
//fp32_output
// fp32_output
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, float>)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment