Commit 2b27d5fc authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'origin/develop' into rosenrodt/gemm-layernorm

parents f689a155 fa9a0a5c
...@@ -236,9 +236,14 @@ template <typename SrcData, ...@@ -236,9 +236,14 @@ template <typename SrcData,
index_t SrcScalarPerVector, index_t SrcScalarPerVector,
index_t SrcScalarStrideInVector, index_t SrcScalarStrideInVector,
bool SrcResetCoordinateAfterRun, bool SrcResetCoordinateAfterRun,
bool InvalidElementAsNaN = false,
typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false> typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v2 struct ThreadwiseTensorSliceTransfer_v2
{ {
static_assert((InvalidElementAsNaN && !std::is_integral<DstData>::value) ||
(!InvalidElementAsNaN),
"Filling invalid element as NaN is only for floating point types");
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
...@@ -318,8 +323,18 @@ struct ThreadwiseTensorSliceTransfer_v2 ...@@ -318,8 +323,18 @@ struct ThreadwiseTensorSliceTransfer_v2
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx + dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
i * src_scalar_step_in_vector); i * src_scalar_step_in_vector);
if constexpr(InvalidElementAsNaN)
{
dst_buf(Number<dst_offset>{}) =
is_src_valid
? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
: NumericLimits<DstData>::QuietNaN();
}
else
{
dst_buf(Number<dst_offset>{}) = dst_buf(Number<dst_offset>{}) =
type_convert<DstData>(src_vector.template AsType<SrcData>()[i]); type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
}
}); });
if constexpr(idx_1d.value != num_access - 1) if constexpr(idx_1d.value != num_access - 1)
......
...@@ -148,6 +148,8 @@ __host__ __device__ constexpr auto min(X x, Ys... ys) ...@@ -148,6 +148,8 @@ __host__ __device__ constexpr auto min(X x, Ys... ys)
template <typename T> template <typename T>
__device__ T exp(T x); __device__ T exp(T x);
// TODO: add f16 support using v_exp_f16
template <> template <>
__device__ float exp<float>(float x) __device__ float exp<float>(float x)
{ {
......
...@@ -17,7 +17,7 @@ struct AccumulateWithNanIgnore ...@@ -17,7 +17,7 @@ struct AccumulateWithNanIgnore
{ {
__device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal) __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
{ {
if(!isnan(currVal)) if(!ck::math::isnan(currVal))
{ {
ReduceOperation{}(accuVal, currVal); ReduceOperation{}(accuVal, currVal);
} }
......
add_subdirectory(src/host_tensor)
add_subdirectory(src/tensor_operation_instance/gpu) add_subdirectory(src/tensor_operation_instance/gpu)
add_subdirectory(src/host_tensor)
add_subdirectory(src/utility) add_subdirectory(src/utility)
...@@ -382,13 +382,8 @@ HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens, ...@@ -382,13 +382,8 @@ HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens,
{ {
} }
void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout);
#if 1 #if 1
// FIXME: remove // FIXME: remove
void bf16_to_f32_(const Tensor<ck::bhalf_t>& src, Tensor<float>& dst);
#endif
template <typename T> template <typename T>
float check_error(const Tensor<T>& ref, const Tensor<T>& result) float check_error(const Tensor<T>& ref, const Tensor<T>& result)
{ {
...@@ -434,3 +429,4 @@ float check_error(const Tensor<T>& ref, const Tensor<T>& result) ...@@ -434,3 +429,4 @@ float check_error(const Tensor<T>& ref, const Tensor<T>& result)
return linf_error; return linf_error;
} }
#endif
...@@ -62,20 +62,20 @@ struct ReferenceBatchedGemm : public device::BaseOperator ...@@ -62,20 +62,20 @@ struct ReferenceBatchedGemm : public device::BaseOperator
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
{ {
float v_a; ADataType v_a;
float v_b; BDataType v_b;
arg.a_element_op_(v_a, static_cast<const float>(arg.a_g_m_k_(g, m, k))); arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
arg.b_element_op_(v_b, static_cast<const float>(arg.b_g_k_n_(g, k, n))); arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
v_acc += v_a * v_b; v_acc += ck::type_convert<float>(v_a) * ck::type_convert<float>(v_b);
} }
float v_c; float v_c;
arg.c_element_op_(v_c, v_acc); arg.c_element_op_(v_c, v_acc);
arg.c_g_m_n_(g, m, n) = v_c; arg.c_g_m_n_(g, m, n) = ck::type_convert<CDataType>(v_c);
}; };
make_ParallelTensorFunctor(f_gmk_gkn_gmn, make_ParallelTensorFunctor(f_gmk_gkn_gmn,
......
...@@ -63,20 +63,21 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -63,20 +63,21 @@ struct ReferenceGemm : public device::BaseOperator
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
{ {
AccDataType v_a; ADataType v_a;
AccDataType v_b; BDataType v_b;
arg.a_element_op_(v_a, static_cast<const AccDataType>(arg.a_m_k_(m, k))); arg.a_element_op_(v_a, arg.a_m_k_(m, k));
arg.b_element_op_(v_b, static_cast<const AccDataType>(arg.b_k_n_(k, n))); arg.b_element_op_(v_b, arg.b_k_n_(k, n));
v_acc += v_a * v_b; v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
} }
AccDataType v_c; AccDataType v_c;
arg.c_element_op_(v_c, v_acc); arg.c_element_op_(v_c, v_acc);
arg.c_m_n_(m, n) = v_c; arg.c_m_n_(m, n) = ck::type_convert<CDataType>(v_c);
}; };
make_ParallelTensorFunctor( make_ParallelTensorFunctor(
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include <vector> #include <vector>
#include "ck/utility/functional2.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
...@@ -159,7 +159,7 @@ check_err(const std::vector<T>& out, ...@@ -159,7 +159,7 @@ check_err(const std::vector<T>& out,
const std::vector<T>& ref, const std::vector<T>& ref,
const std::string& msg = "Error: Incorrect results!", const std::string& msg = "Error: Incorrect results!",
double = 0, double = 0,
double = 0) double atol = 0)
{ {
if(out.size() != ref.size()) if(out.size() != ref.size())
{ {
...@@ -179,7 +179,7 @@ check_err(const std::vector<T>& out, ...@@ -179,7 +179,7 @@ check_err(const std::vector<T>& out,
int64_t r = ref[i]; int64_t r = ref[i];
err = std::abs(o - r); err = std::abs(o - r);
if(err > 0) if(err > atol)
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
......
This diff is collapsed.
Markdown is supported
Attach a file by drag &amp; drop or click to upload. (0%)
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment