Commit 23f99eb4 authored by Jing Zhang
Browse files

fixed

parent 6e3c786e
......@@ -91,21 +91,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(stride == 0)
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1)
{
// give a chance if stride is zero, return a default packed stride
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return col;
return static_cast<std::size_t>(col);
}
else
{
return row;
return static_cast<std::size_t>(row);
}
}
else
return stride;
return static_cast<std::size_t>(stride);
};
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
......
......@@ -266,18 +266,18 @@ struct Tensor
using Data = std::vector<T>;
template <typename X>
Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(GetElementSpaceSize())
{
}
template <typename X, typename Y>
Tensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
: mDesc(lens, strides), mData(mDesc.GetElementSpaceSize())
: mDesc(lens, strides), mData(GetElementSpaceSize())
{
}
template <typename Lengths>
Tensor(const Lengths& lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
Tensor(const Lengths& lens) : mDesc(lens), mData(GetElementSpaceSize())
{
}
......@@ -287,7 +287,7 @@ struct Tensor
{
}
Tensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}
Tensor(const Descriptor& desc) : mDesc(desc), mData(GetElementSpaceSize()) {}
template <typename OutT>
Tensor<OutT> CopyAsType() const
......@@ -324,7 +324,7 @@ struct Tensor
std::size_t GetElementSpaceSize() const
{
if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
return mDesc.GetElementSpaceSize() / 2;
else
return mDesc.GetElementSpaceSize();
......@@ -475,7 +475,7 @@ struct Tensor
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
{
if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mDesc.GetOffsetFromMultiIndex(is...) / 2;
}
......@@ -488,7 +488,7 @@ struct Tensor
template <typename... Is>
T& operator()(Is... is)
{
if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
......@@ -501,7 +501,7 @@ struct Tensor
template <typename... Is>
const T& operator()(Is... is) const
{
if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
......@@ -513,7 +513,7 @@ struct Tensor
T& operator()(std::vector<std::size_t> idx)
{
if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}
......@@ -525,7 +525,7 @@ struct Tensor
const T& operator()(std::vector<std::size_t> idx) const
{
if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}
......
......@@ -45,24 +45,6 @@ __global__ void
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
p_shared,
karg);
// int q = 0x01234567;
// ck::vector_type<ck::bhalf_t, 8> res;
// res.template AsType<ck::bhalf4_t>()(ck::Number<0>{}) = ck::pki4_to_bhalf4(q >> 16);
// res.template AsType<ck::bhalf4_t>()(ck::Number<1>{}) = ck::pki4_to_bhalf4(q);
// if(threadIdx.x == 0 && blockIdx.x == 0)
// printf("%f %f %f %f %f %f %f %f\n",
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<0>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<1>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<2>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<3>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<4>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<5>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<6>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<7>{}])
//);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
......
......@@ -1870,12 +1870,12 @@ using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
// u8
//using uint8x2_t = typename vector_type<uint8_t, 2>::type;
//using uint8x4_t = typename vector_type<uint8_t, 4>::type;
//using uint8x8_t = typename vector_type<uint8_t, 8>::type;
//using uint8x16_t = typename vector_type<uint8_t, 16>::type;
//using uint8x32_t = typename vector_type<uint8_t, 32>::type;
//using uint8x64_t = typename vector_type<uint8_t, 64>::type;
// using uint8x2_t = typename vector_type<uint8_t, 2>::type;
// using uint8x4_t = typename vector_type<uint8_t, 4>::type;
// using uint8x8_t = typename vector_type<uint8_t, 8>::type;
// using uint8x16_t = typename vector_type<uint8_t, 16>::type;
// using uint8x32_t = typename vector_type<uint8_t, 32>::type;
// using uint8x64_t = typename vector_type<uint8_t, 64>::type;
template <typename T>
struct NumericLimits
......
......@@ -17,7 +17,7 @@ fi
cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_HIP_FLAGS="-gline-tables-only -Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
-D CMAKE_HIP_FLAGS="-g -Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O0 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \
-D GPU_TARGETS=$GPU_TARGETS \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment