Commit 20e47518 authored by fsx950223's avatar fsx950223
Browse files

merge upstream

parents 69224aac 67f39ad1
......@@ -500,6 +500,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
#if DEBUG_LOG
std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{"
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0)
<< ", "
......@@ -520,6 +521,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;
#endif
if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_m_k_,
arg.gemm_desc_kernel_arg_[i].b_grid_desc_n_k_,
......
......@@ -95,6 +95,8 @@ struct Scale
y = scale_ * x;
};
__host__ __device__ auto Value() const { return scale_; }
float scale_;
};
......
This diff is collapsed.
......@@ -168,6 +168,22 @@ __device__ double exp<double>(double x)
return exp(x);
}
// disallow implicit type casting
// Primary template for log: only declared in this span, with no generic body.
// T is deduced from the argument, and the explicit specializations that follow
// cover float and double.
// NOTE(review): presumably any other T is meant to fail at build time because no
// definition exists for it — confirm no other specializations live elsewhere.
template <typename T>
__device__ T log(T x);
// Single-precision specialization of log for device code.
// Forwards the argument unchanged to the __logf intrinsic and returns its result.
// NOTE(review): __logf is presumably the reduced-accuracy fast-math variant —
// confirm the precision is acceptable for every caller.
template <>
__device__ float log<float>(float x)
{
    const float result = __logf(x);
    return result;
}
// Double-precision specialization of log for device code.
//
// The call is qualified with `::` so it always resolves to the global-scope
// math-library log(double). When this template is declared inside a namespace,
// an unqualified `log(x)` only sees the template itself (its name hides the
// global function during unqualified lookup, and double has no associated
// namespaces for ADL), so the call would bind to this very specialization and
// recurse without end. With the qualifier, the non-template global function is
// chosen regardless of the enclosing scope.
template <>
__device__ double log<double>(double x)
{
    return ::log(x);
}
// greatest common divisor, aka highest common factor
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
{
......
......@@ -3,7 +3,9 @@
#pragma once
#ifndef __HIP_DEVICE_COMPILE__
#include <cmath>
#endif
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment