Commit 4f566c62 authored by Chao Liu's avatar Chao Liu
Browse files

vector/scalar pointer cast use c-style pointer cast instead of reinterpret_cast

parent 172036d7
......@@ -350,8 +350,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
p_a, p_b, reg_c);
......@@ -384,8 +384,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c);
}
......@@ -417,8 +417,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c);
}
......@@ -450,8 +450,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
}
......@@ -483,8 +483,8 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
}
......
......@@ -2,6 +2,7 @@
#define CK_AMD_INLINE_ASM_HPP
#include "data_type.hpp"
#include "c_style_pointer_cast.hpp"
namespace ck {
......@@ -53,9 +54,9 @@ __device__ void
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
{
// TODO remove pointer casting
const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a);
const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&a);
const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
// do dot2 two times
asm volatile("\n \
......@@ -114,11 +115,11 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
float& c3)
{
// TODO remove pointer casting
const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a);
const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
const half2_t* p_b2_half2 = reinterpret_cast<const half2_t*>(&b2);
const half2_t* p_b3_half2 = reinterpret_cast<const half2_t*>(&b3);
const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&a);
const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
const half2_t* p_b2_half2 = c_style_pointer_cast<const half2_t*>(&b2);
const half2_t* p_b3_half2 = c_style_pointer_cast<const half2_t*>(&b3);
// do dot2 two times
asm volatile("\n \
......@@ -160,11 +161,11 @@ __device__ void amd_assembly_outer_product_1x4(half8_t a,
{
// TODO remove pointer casting
const half4_t* p_a_half4 = reinterpret_cast<const half4_t*>(&a);
const half4_t* p_b0_half4 = reinterpret_cast<const half4_t*>(&b0);
const half4_t* p_b1_half4 = reinterpret_cast<const half4_t*>(&b1);
const half4_t* p_b2_half4 = reinterpret_cast<const half4_t*>(&b2);
const half4_t* p_b3_half4 = reinterpret_cast<const half4_t*>(&b3);
const half4_t* p_a_half4 = c_style_pointer_cast<const half4_t*>(&a);
const half4_t* p_b0_half4 = c_style_pointer_cast<const half4_t*>(&b0);
const half4_t* p_b1_half4 = c_style_pointer_cast<const half4_t*>(&b1);
const half4_t* p_b2_half4 = c_style_pointer_cast<const half4_t*>(&b2);
const half4_t* p_b3_half4 = c_style_pointer_cast<const half4_t*>(&b3);
amd_assembly_outer_product_1x4(
p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);
......@@ -184,11 +185,11 @@ __device__ void amd_assembly_outer_product_1x4(half16_t a,
float& c3)
{
// TODO remove pointer casting
const half8_t* p_a_half8 = reinterpret_cast<const half8_t*>(&a);
const half8_t* p_b0_half8 = reinterpret_cast<const half8_t*>(&b0);
const half8_t* p_b1_half8 = reinterpret_cast<const half8_t*>(&b1);
const half8_t* p_b2_half8 = reinterpret_cast<const half8_t*>(&b2);
const half8_t* p_b3_half8 = reinterpret_cast<const half8_t*>(&b3);
const half8_t* p_a_half8 = c_style_pointer_cast<const half8_t*>(&a);
const half8_t* p_b0_half8 = c_style_pointer_cast<const half8_t*>(&b0);
const half8_t* p_b1_half8 = c_style_pointer_cast<const half8_t*>(&b1);
const half8_t* p_b2_half8 = c_style_pointer_cast<const half8_t*>(&b2);
const half8_t* p_b3_half8 = c_style_pointer_cast<const half8_t*>(&b3);
amd_assembly_outer_product_1x4(
p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);
......
......@@ -51,7 +51,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
const auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
#if 1
#if 0
// [M, N, K0, K1] = [128, 128, 8, 1] for fp32
// cdata = 64, BlockSize = 256
constexpr index_t BlockSize = 256;
......@@ -81,7 +81,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 1>;
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
#elif 1
#elif 0
// [M, N, K0, K1] = [128, 128, 8, 2] for fp16
// cdata = 64, BlockSize = 256
constexpr index_t BlockSize = 256;
......
......@@ -34,7 +34,7 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
first = false;
else
os << delim;
os << T{v};
os << static_cast<T>(v);
}
return os;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment