Commit 46d1d913 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed comments

parent ea9d7396
......@@ -9,7 +9,6 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
......@@ -62,7 +62,7 @@ struct Multiply
__host__ __device__ constexpr void
operator()(ck::half_t& a, const ck::half_t& a0, const float& a1) const
{
a = a0 * a1;
a = ck::type_convert<ck::half_t>(ck::type_convert<float>(a0) * a1);
}
};
......@@ -231,8 +231,8 @@ int main(int argc, char* argv[])
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
"wrong! device_contraction with the specified compilation parameters does "
"not support this problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -515,16 +515,6 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
e_nz_stride_ = e_ms_ns_stride[NumDimM + NumDimN - 1];
}
void Print() const
{
// std::cout << "A[M, K]: " << as_grid_desc_m_k_ << std::endl;
// std::cout << "B[N, K]: " << bs_grid_desc_n_k_ << std::endl;
// static_for<0, NumDTensor, 1>{}(
//[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
// std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
}
// private:
// pointers
typename GridwiseGemm::AsGridPointer p_as_grid_;
typename GridwiseGemm::BsGridPointer p_bs_grid_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment