Commit c9915508 authored by letaoqin's avatar letaoqin
Browse files

fix group d data type

parent 98df59c6
...@@ -48,12 +48,11 @@ using ADataType = DataType; ...@@ -48,12 +48,11 @@ using ADataType = DataType;
using B0DataType = DataType; using B0DataType = DataType;
using B1DataType = DataType; using B1DataType = DataType;
using AccDataType = F32; using AccDataType = F32;
using DDataType = F16;
using CShuffleDataType = F32; using CShuffleDataType = F32;
using CDataType = DataType; using CDataType = DataType;
using ZDataType = U16; // INT32 using ZDataType = U16; // INT32
using LSEDataType = F32; using LSEDataType = F32;
using Acc0BiasDataType = DDataType; using Acc0BiasDataType = F16;
using Acc1BiasDataType = void; using Acc1BiasDataType = void;
static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimG = 2;
......
...@@ -67,7 +67,7 @@ int run(int argc, char* argv[]) ...@@ -67,7 +67,7 @@ int run(int argc, char* argv[])
std::vector<Tensor<B0DataType>> b0_tensors; std::vector<Tensor<B0DataType>> b0_tensors;
std::vector<Tensor<B1DataType>> b1_tensors; std::vector<Tensor<B1DataType>> b1_tensors;
std::vector<Tensor<CDataType>> c_tensors; std::vector<Tensor<CDataType>> c_tensors;
std::vector<Tensor<DDataType>> d_tensors; std::vector<Tensor<Acc0BiasDataType>> d_tensors;
std::vector<Tensor<ZDataType>> z_tensors; std::vector<Tensor<ZDataType>> z_tensors;
std::vector<Tensor<LSEDataType>> lse_tensors; std::vector<Tensor<LSEDataType>> lse_tensors;
...@@ -157,7 +157,7 @@ int run(int argc, char* argv[]) ...@@ -157,7 +157,7 @@ int run(int argc, char* argv[])
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); Tensor<Acc0BiasDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
...@@ -165,7 +165,7 @@ int run(int argc, char* argv[]) ...@@ -165,7 +165,7 @@ int run(int argc, char* argv[])
flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch; flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
sizeof(DDataType) * M * N * (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) * sizeof(Acc0BiasDataType) * M * N * (std::is_void<Acc0BiasDataType>::value ? 0 : 1)) *
Batch; Batch;
if(i < 4) if(i < 4)
...@@ -189,25 +189,25 @@ int run(int argc, char* argv[]) ...@@ -189,25 +189,25 @@ int run(int argc, char* argv[])
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-1, 1}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-1, 1});
break; break;
case 2: case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5});
break; break;
case 3: case 3:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DDataType>{1}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
break; break;
default: default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DDataType>{1}); d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
} }
a_tensors.push_back(a_gs_ms_ks); a_tensors.push_back(a_gs_ms_ks);
...@@ -227,7 +227,7 @@ int run(int argc, char* argv[]) ...@@ -227,7 +227,7 @@ int run(int argc, char* argv[])
c_tensors_device.emplace_back(std::make_unique<DeviceMem>( c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize())); sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()));
d_tensors_device.emplace_back(std::make_unique<DeviceMem>( d_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize())); sizeof(Acc0BiasDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()));
z_tensors_device.emplace_back(std::make_unique<DeviceMem>( z_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize())); sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize()));
lse_tensors_device.emplace_back(std::make_unique<DeviceMem>( lse_tensors_device.emplace_back(std::make_unique<DeviceMem>(
...@@ -359,7 +359,7 @@ int run(int argc, char* argv[]) ...@@ -359,7 +359,7 @@ int run(int argc, char* argv[])
Tensor<B0DataType> b0_g_k_n({G0 * G1, K, N}); Tensor<B0DataType> b0_g_k_n({G0 * G1, K, N});
Tensor<B1DataType> b1_g_n_o({G0 * G1, N, O}); Tensor<B1DataType> b1_g_n_o({G0 * G1, N, O});
Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0 Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0
Tensor<AccDataType> d_g_m_n({G0 * G1, M, N}); Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N});
Tensor<ADataType> a1_g_m_n({G0 * G1, M, N}); // scratch object after softmax Tensor<ADataType> a1_g_m_n({G0 * G1, M, N}); // scratch object after softmax
Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N}); // scratch object after softmax Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1 Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1
......
Markdown is supported
0% — Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment