Commit 0c359ae8 authored by danyao12

ZDataType can be set to U16/INT32 in fwd&bwd&train examples

parent 3b57967f
@@ -55,6 +55,7 @@ using F16 = ck::half_t;
 using BF16 = ck::bhalf_t;
 using F32 = float;
 using U16 = unsigned short;
+using INT32 = int32_t;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Scale = ck::tensor_operation::element_wise::Scale;
@@ -68,7 +69,7 @@ using GemmDataType = BF16;
 using AccDataType = F32;
 using ShuffleDataType = F32;
 using LSEDataType = F32;
-using ZDataType = U16;
+using ZDataType = U16; // INT32
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
@@ -422,7 +423,7 @@ using ReferenceGemm1GradInstance = ck::tensor_operation::host::ReferenceBatchedG
 // Ref dropout
 using ReferenceDropoutInstance =
-    ck::tensor_operation::host::ReferenceDropout<ushort, InputDataType, InputDataType>;
+    ck::tensor_operation::host::ReferenceDropout<ZDataType, InputDataType, InputDataType>;
 template <typename TensorQ,
           typename TensorK,
@@ -442,7 +443,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ushort p_dropout_in_16bits,
+                            ZDataType p_dropout_in_16bits,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -549,7 +550,7 @@ int run(int argc, char* argv[])
     }
     float p_dropout = 1 - p_drop;
-    uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
     float rp_dropout = 1.0 / p_dropout;
     float alpha = 1.f / std::sqrt(K);
......
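Side note on the threshold above: `std::floor(p_dropout * 65535.0)` maps the keep-probability onto the full 16-bit range, so the result always fits in a `U16` and therefore also in an `INT32`, which is what lets `ZDataType` swap freely between the two. A minimal standalone sketch of the quantization (hypothetical driver code, not part of the example):

```cpp
#include <cmath>
#include <cstdint>
#include <iostream>

int main()
{
    // Keep-probability derived from the drop-probability, as in the examples.
    float p_drop    = 0.2f;
    float p_dropout = 1 - p_drop;

    // Quantize onto [0, 65535]; the value fits in 16 bits, so it is
    // representable as either U16 or INT32 without loss.
    uint16_t as_u16 = uint16_t(std::floor(p_dropout * 65535.0));
    int32_t as_i32  = int32_t(std::floor(p_dropout * 65535.0));

    std::cout << as_u16 << " == " << as_i32 << std::endl; // 52428 == 52428
}
```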
@@ -38,6 +38,7 @@ using F16 = ck::half_t;
 using BF16 = ck::bhalf_t;
 using F32 = float;
 using U16 = unsigned short;
+using INT32 = int32_t;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
@@ -49,7 +50,7 @@ using B1DataType = DataType;
 using AccDataType = F32;
 using CShuffleDataType = F32;
 using CDataType = DataType;
-using ZDataType = U16;
+using ZDataType = U16; // INT32
 using LSEDataType = F32;
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
......
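Because every use of the dropout-mask type in these examples goes through the `ZDataType` alias (the reference dropout instance, the host helper signature, and the threshold computation), switching an example between the two encodings is a one-line edit. A hedged sketch of the pattern; the `static_assert` is an illustrative addition, not part of the commit:

```cpp
#include <cstdint>
#include <limits>

using U16   = unsigned short;
using INT32 = int32_t;

// Flip this single alias to change the Z (dropout mask) storage type
// everywhere it is referenced.
using ZDataType = U16; // or INT32

// The quantized dropout threshold is at most 65535, so either choice
// can represent it exactly.
static_assert(std::numeric_limits<ZDataType>::max() >= 65535,
              "ZDataType must hold a 16-bit dropout threshold");

int main() { return 0; }
```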
@@ -64,6 +64,7 @@ using F16 = ck::half_t;
 using BF16 = ck::bhalf_t;
 using F32 = float;
 using U16 = unsigned short;
+using INT32 = int32_t;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Scale = ck::tensor_operation::element_wise::Scale;
@@ -71,13 +72,13 @@ using Scale = ck::tensor_operation::element_wise::Scale;
 using QKVElementOp = PassThrough;
 using YElementOp = PassThrough;
-using InputDataType = BF16;
-using OutputDataType = F32;
-using GemmDataType = BF16;
+using InputDataType = F16;
+using OutputDataType = F16;
+using GemmDataType = F16;
 using AccDataType = F32;
 using ShuffleDataType = F32;
 using LSEDataType = F32;
-using ZDataType = U16;
+using ZDataType = INT32; // U16
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
@@ -641,7 +642,7 @@ using ReferenceGemm1GradInstance = ck::tensor_operation::host::ReferenceBatchedG
 // Ref dropout
 using ReferenceDropoutInstance =
-    ck::tensor_operation::host::ReferenceDropout<ushort, InputDataType, InputDataType>;
+    ck::tensor_operation::host::ReferenceDropout<ZDataType, InputDataType, InputDataType>;
 template <typename TensorQ,
           typename TensorK,
@@ -661,7 +662,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ushort p_dropout_in_16bits,
+                            ZDataType p_dropout_in_16bits,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -768,7 +769,7 @@ int run(int argc, char* argv[])
     }
     float p_dropout = 1 - p_drop;
-    uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
     float rp_dropout = 1.0 / p_dropout;
     float alpha = 1.f / std::sqrt(K);
......
@@ -54,6 +54,7 @@ using F16 = ck::half_t;
 using BF16 = ck::bhalf_t;
 using F32 = float;
 using U16 = unsigned short;
+using INT32 = int32_t;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Scale = ck::tensor_operation::element_wise::Scale;
@@ -67,7 +68,7 @@ using GemmDataType = BF16;
 using AccDataType = F32;
 using ShuffleDataType = F32;
 using LSEDataType = F32;
-using ZDataType = U16;
+using ZDataType = INT32; // U16
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
@@ -421,7 +422,7 @@ using ReferenceGemm1GradInstance = ck::tensor_operation::host::ReferenceBatchedG
 // Ref dropout
 using ReferenceDropoutInstance =
-    ck::tensor_operation::host::ReferenceDropout<ushort, InputDataType, InputDataType>;
+    ck::tensor_operation::host::ReferenceDropout<ZDataType, InputDataType, InputDataType>;
 template <typename TensorQ,
           typename TensorK,
@@ -441,7 +442,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ushort p_dropout_in_16bits,
+                            ZDataType p_dropout_in_16bits,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -536,7 +537,7 @@ int run(int argc, char* argv[])
     }
     float p_dropout = 1 - p_drop;
-    uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
     float rp_dropout = 1.0 / p_dropout;
     auto gemm = DeviceGemmInstance{};
......
@@ -38,6 +38,7 @@ using F16 = ck::half_t;
 using BF16 = ck::bhalf_t;
 using F32 = float;
 using U16 = unsigned short;
+using INT32 = int32_t;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
@@ -49,7 +50,7 @@ using B1DataType = DataType;
 using AccDataType = F32;
 using CShuffleDataType = F32;
 using CDataType = DataType;
-using ZDataType = U16;
+using ZDataType = INT32; // U16
 using LSEDataType = F32;
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
......
@@ -63,6 +63,7 @@ using F16 = ck::half_t;
 using BF16 = ck::bhalf_t;
 using F32 = float;
 using U16 = unsigned short;
+using INT32 = int32_t;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Scale = ck::tensor_operation::element_wise::Scale;
@@ -76,7 +77,7 @@ using GemmDataType = BF16;
 using AccDataType = F32;
 using ShuffleDataType = F32;
 using LSEDataType = F32;
-using ZDataType = U16;
+using ZDataType = INT32; // U16
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
@@ -640,7 +641,7 @@ using ReferenceGemm1GradInstance = ck::tensor_operation::host::ReferenceBatchedG
 // Ref dropout
 using ReferenceDropoutInstance =
-    ck::tensor_operation::host::ReferenceDropout<ushort, InputDataType, InputDataType>;
+    ck::tensor_operation::host::ReferenceDropout<ZDataType, InputDataType, InputDataType>;
 template <typename TensorQ,
           typename TensorK,
@@ -660,7 +661,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             TensorLSE& lse_g_m,
                             TensorP& p_drop_g_m_n,
                             TensorZ& z_g_m_n,
-                            ushort p_dropout_in_16bits,
+                            ZDataType p_dropout_in_16bits,
                             float rp_dropout)
 {
     // S = alpha * Q * K^T
@@ -755,7 +756,7 @@ int run(int argc, char* argv[])
     }
     float p_dropout = 1 - p_drop;
-    uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
+    ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
     float rp_dropout = 1.0 / p_dropout;
     auto gemm_fwd = DeviceGemmInstanceFWD{};
......
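For context on how the Z tensor and the threshold interact: the host reference keeps an element when its pre-generated Z value falls at or below `p_dropout_in_16bits` and rescales survivors by `rp_dropout`. The sketch below illustrates that scheme with a hypothetical free function; it is a simplified illustration, not CK's actual `ReferenceDropout` implementation:

```cpp
#include <cstdint>
#include <vector>

using ZDataType = int32_t; // or unsigned short, matching the examples

// Apply the dropout mask the way a host reference would: z values at or
// below the threshold keep the element (rescaled), the rest are zeroed.
std::vector<float> apply_dropout(const std::vector<float>& p,
                                 const std::vector<ZDataType>& z,
                                 ZDataType p_dropout_in_16bits,
                                 float rp_dropout)
{
    std::vector<float> out(p.size());
    for(std::size_t i = 0; i < p.size(); ++i)
    {
        out[i] = (z[i] <= p_dropout_in_16bits) ? p[i] * rp_dropout : 0.f;
    }
    return out;
}
```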
@@ -257,5 +257,48 @@ check_err(const Range& out, const RefRange& ref, unsigned short atol = 1)
     return res;
 }
+
+template <typename Range, typename RefRange>
+typename std::enable_if<
+    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+        std::is_same_v<ranges::range_value_t<Range>, int32_t>,
+    bool>::type
+check_err(const Range& out, const RefRange& ref, int32_t atol = 1)
+{
+    const std::string& msg = "Error: Incorrect INT32 results!";
+    if(out.size() != ref.size())
+    {
+        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
+                  << std::endl;
+        return false;
+    }
+
+    bool res{true};
+    int err_count   = 0;
+    int32_t err     = 0;
+    int32_t max_err = std::numeric_limits<int32_t>::min();
+    for(std::size_t i = 0; i < ref.size(); ++i)
+    {
+        const int32_t o = *std::next(std::begin(out), i);
+        const int32_t r = *std::next(std::begin(ref), i);
+        err             = (o > r) ? o - r : r - o;
+        if(err > atol)
+        {
+            max_err = err > max_err ? err : max_err;
+            err_count++;
+            if(err_count < 5)
+            {
+                std::cerr << msg << std::setw(12) << " out[" << i << "] != ref[" << i
+                          << "]: " << o << " != " << r << std::endl;
+            }
+            res = false;
+        }
+    }
+    if(!res)
+    {
+        std::cerr << std::setw(12) << "max err: " << max_err << std::endl;
+    }
+    return res;
+}
 } // namespace utils
 } // namespace ck
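With the `INT32` overload in place, Z tensors of either type can be validated through the same `check_err` entry point. A hypothetical usage sketch (data and tolerance invented for illustration; the include path is assumed from the CK source tree layout):

```cpp
#include <cstdint>
#include <vector>

#include "ck/library/utility/check_err.hpp" // assumed location of the overloads above

int main()
{
    std::vector<int32_t> out{100, 200, 301};
    std::vector<int32_t> ref{100, 200, 300};

    // Passes: the largest absolute difference (1) does not exceed atol = 1.
    bool ok = ck::utils::check_err(out, ref, /*atol=*/1);
    return ok ? 0 : 1;
}
```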