Commit fd1cf141 authored by Jing Zhang
Browse files

add fp32 out client example

parent abef7c4e
...@@ -26,7 +26,7 @@ using ADataType = F16; ...@@ -26,7 +26,7 @@ using ADataType = F16;
using BDataType = F16; using BDataType = F16;
using D0DataType = F32; using D0DataType = F32;
using DsDataType = ck::Tuple<D0DataType>; using DsDataType = ck::Tuple<D0DataType>;
using EDataType = F16; using EDataType = F32;
using ALayout = Row; using ALayout = Row;
using BLayout = Col; using BLayout = Col;
......
...@@ -16,7 +16,7 @@ namespace tensor_operation { ...@@ -16,7 +16,7 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
//fp16_output // fp16_output
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances( void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row, std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row, Row,
...@@ -43,7 +43,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances( ...@@ -43,7 +43,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
PassThrough, PassThrough,
AddBias>>>& instances); AddBias>>>& instances);
//fp32_output // fp32_output
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances( void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row, std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row, Row,
...@@ -70,10 +70,6 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances( ...@@ -70,10 +70,6 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(
PassThrough, PassThrough,
AddBias>>>& instances); AddBias>>>& instances);
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename ELayout, typename ELayout,
...@@ -109,7 +105,7 @@ struct DeviceOperationInstanceFactory< ...@@ -109,7 +105,7 @@ struct DeviceOperationInstanceFactory<
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
//fp16_output // fp16_output
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> && if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>) is_same_v<EDataType, half_t>)
{ {
...@@ -125,7 +121,7 @@ struct DeviceOperationInstanceFactory< ...@@ -125,7 +121,7 @@ struct DeviceOperationInstanceFactory<
} }
} }
//fp32_output // fp32_output
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> && if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, float>) is_same_v<EDataType, float>)
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment