Commit 23f99eb4 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed

parent 6e3c786e
...@@ -91,21 +91,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -91,21 +91,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}; };
auto f_get_default_stride = auto f_get_default_stride =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == 0) if(stride == -1)
{ {
// give a chance if stride is zero, return a default packed stride // give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{ {
return col; return static_cast<std::size_t>(col);
} }
else else
{ {
return row; return static_cast<std::size_t>(row);
} }
} }
else else
return stride; return static_cast<std::size_t>(stride);
}; };
StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
......
...@@ -266,18 +266,18 @@ struct Tensor ...@@ -266,18 +266,18 @@ struct Tensor
using Data = std::vector<T>; using Data = std::vector<T>;
template <typename X> template <typename X>
Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize()) Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(GetElementSpaceSize())
{ {
} }
template <typename X, typename Y> template <typename X, typename Y>
Tensor(std::initializer_list<X> lens, std::initializer_list<Y> strides) Tensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
: mDesc(lens, strides), mData(mDesc.GetElementSpaceSize()) : mDesc(lens, strides), mData(GetElementSpaceSize())
{ {
} }
template <typename Lengths> template <typename Lengths>
Tensor(const Lengths& lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize()) Tensor(const Lengths& lens) : mDesc(lens), mData(GetElementSpaceSize())
{ {
} }
...@@ -287,7 +287,7 @@ struct Tensor ...@@ -287,7 +287,7 @@ struct Tensor
{ {
} }
Tensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {} Tensor(const Descriptor& desc) : mDesc(desc), mData(GetElementSpaceSize()) {}
template <typename OutT> template <typename OutT>
Tensor<OutT> CopyAsType() const Tensor<OutT> CopyAsType() const
...@@ -324,7 +324,7 @@ struct Tensor ...@@ -324,7 +324,7 @@ struct Tensor
std::size_t GetElementSpaceSize() const std::size_t GetElementSpaceSize() const
{ {
if constexpr(ck::is_same_v<T, ck::pk_i4_t>) if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
return mDesc.GetElementSpaceSize() / 2; return mDesc.GetElementSpaceSize() / 2;
else else
return mDesc.GetElementSpaceSize(); return mDesc.GetElementSpaceSize();
...@@ -475,7 +475,7 @@ struct Tensor ...@@ -475,7 +475,7 @@ struct Tensor
template <typename... Is> template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const std::size_t GetOffsetFromMultiIndex(Is... is) const
{ {
if constexpr(ck::is_same_v<T, ck::pk_i4_t>) if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{ {
return mDesc.GetOffsetFromMultiIndex(is...) / 2; return mDesc.GetOffsetFromMultiIndex(is...) / 2;
} }
...@@ -488,7 +488,7 @@ struct Tensor ...@@ -488,7 +488,7 @@ struct Tensor
template <typename... Is> template <typename... Is>
T& operator()(Is... is) T& operator()(Is... is)
{ {
if constexpr(ck::is_same_v<T, ck::pk_i4_t>) if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{ {
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
} }
...@@ -501,7 +501,7 @@ struct Tensor ...@@ -501,7 +501,7 @@ struct Tensor
template <typename... Is> template <typename... Is>
const T& operator()(Is... is) const const T& operator()(Is... is) const
{ {
if constexpr(ck::is_same_v<T, ck::pk_i4_t>) if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{ {
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
} }
...@@ -513,7 +513,7 @@ struct Tensor ...@@ -513,7 +513,7 @@ struct Tensor
T& operator()(std::vector<std::size_t> idx) T& operator()(std::vector<std::size_t> idx)
{ {
if constexpr(ck::is_same_v<T, ck::pk_i4_t>) if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{ {
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
} }
...@@ -525,7 +525,7 @@ struct Tensor ...@@ -525,7 +525,7 @@ struct Tensor
const T& operator()(std::vector<std::size_t> idx) const const T& operator()(std::vector<std::size_t> idx) const
{ {
if constexpr(ck::is_same_v<T, ck::pk_i4_t>) if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
{ {
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
} }
......
...@@ -45,24 +45,6 @@ __global__ void ...@@ -45,24 +45,6 @@ __global__ void
karg.p_c_grid + splitk_batch_offset.c_reduce_offset, karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
p_shared, p_shared,
karg); karg);
// int q = 0x01234567;
// ck::vector_type<ck::bhalf_t, 8> res;
// res.template AsType<ck::bhalf4_t>()(ck::Number<0>{}) = ck::pki4_to_bhalf4(q >> 16);
// res.template AsType<ck::bhalf4_t>()(ck::Number<1>{}) = ck::pki4_to_bhalf4(q);
// if(threadIdx.x == 0 && blockIdx.x == 0)
// printf("%f %f %f %f %f %f %f %f\n",
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<0>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<1>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<2>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<3>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<4>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<5>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<6>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<7>{}])
//);
#else #else
ignore = karg; ignore = karg;
#endif // end of if (defined(__gfx9__)) #endif // end of if (defined(__gfx9__))
......
...@@ -1870,12 +1870,12 @@ using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type; ...@@ -1870,12 +1870,12 @@ using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type; using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
// u8 // u8
//using uint8x2_t = typename vector_type<uint8_t, 2>::type; // using uint8x2_t = typename vector_type<uint8_t, 2>::type;
//using uint8x4_t = typename vector_type<uint8_t, 4>::type; // using uint8x4_t = typename vector_type<uint8_t, 4>::type;
//using uint8x8_t = typename vector_type<uint8_t, 8>::type; // using uint8x8_t = typename vector_type<uint8_t, 8>::type;
//using uint8x16_t = typename vector_type<uint8_t, 16>::type; // using uint8x16_t = typename vector_type<uint8_t, 16>::type;
//using uint8x32_t = typename vector_type<uint8_t, 32>::type; // using uint8x32_t = typename vector_type<uint8_t, 32>::type;
//using uint8x64_t = typename vector_type<uint8_t, 64>::type; // using uint8x64_t = typename vector_type<uint8_t, 64>::type;
template <typename T> template <typename T>
struct NumericLimits struct NumericLimits
......
...@@ -17,7 +17,7 @@ fi ...@@ -17,7 +17,7 @@ fi
cmake \ cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_HIP_FLAGS="-gline-tables-only -Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ -D CMAKE_HIP_FLAGS="-g -Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O0 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
-D CMAKE_BUILD_TYPE=Release \ -D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \ -D BUILD_DEV=ON \
-D GPU_TARGETS=$GPU_TARGETS \ -D GPU_TARGETS=$GPU_TARGETS \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment