Commit 96e212c1 authored by Chao Liu
Browse files

clean up

parent 4b306e5b
......@@ -606,7 +606,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_addr_offset + 4 * sizeof(half_t),
0);
#else
llvm_amdgcn_raw_buffer_store_fp32x4(as_type<float4_t>(src_thread_data),
llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<float4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
......
......@@ -133,7 +133,7 @@
// workaround for compiler generating inefficient ds_write instructions
#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 0
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif
namespace ck {
......
......@@ -137,7 +137,7 @@ struct vector_type<T, 1>
union
{
T d1_;
StaticallyIndexedArray_v2<T, 1> d1x1_;
StaticallyIndexedArray<T, 1> d1x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
......@@ -172,8 +172,8 @@ struct vector_type<T, 2>
union
{
d2_t d2_;
StaticallyIndexedArray_v2<d1_t, 2> d1x2_;
StaticallyIndexedArray_v2<d2_t, 1> d2x1_;
StaticallyIndexedArray<d1_t, 2> d1x2_;
StaticallyIndexedArray<d2_t, 1> d2x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
......@@ -223,9 +223,9 @@ struct vector_type<T, 4>
union
{
d4_t d4_;
StaticallyIndexedArray_v2<d1_t, 4> d1x4_;
StaticallyIndexedArray_v2<d2_t, 2> d2x2_;
StaticallyIndexedArray_v2<d4_t, 1> d4x1_;
StaticallyIndexedArray<d1_t, 4> d1x4_;
StaticallyIndexedArray<d2_t, 2> d2x2_;
StaticallyIndexedArray<d4_t, 1> d4x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
......@@ -286,10 +286,10 @@ struct vector_type<T, 8>
union
{
d8_t d8_;
StaticallyIndexedArray_v2<d1_t, 8> d1x8_;
StaticallyIndexedArray_v2<d2_t, 4> d2x4_;
StaticallyIndexedArray_v2<d4_t, 2> d4x2_;
StaticallyIndexedArray_v2<d8_t, 1> d8x1_;
StaticallyIndexedArray<d1_t, 8> d1x8_;
StaticallyIndexedArray<d2_t, 4> d2x4_;
StaticallyIndexedArray<d4_t, 2> d4x2_;
StaticallyIndexedArray<d8_t, 1> d8x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
......@@ -361,11 +361,11 @@ struct vector_type<T, 16>
union
{
d16_t d16_;
StaticallyIndexedArray_v2<d1_t, 16> d1x16_;
StaticallyIndexedArray_v2<d2_t, 8> d2x8_;
StaticallyIndexedArray_v2<d4_t, 4> d4x4_;
StaticallyIndexedArray_v2<d8_t, 2> d8x2_;
StaticallyIndexedArray_v2<d16_t, 1> d16x1_;
StaticallyIndexedArray<d1_t, 16> d1x16_;
StaticallyIndexedArray<d2_t, 8> d2x8_;
StaticallyIndexedArray<d4_t, 4> d4x4_;
StaticallyIndexedArray<d8_t, 2> d8x2_;
StaticallyIndexedArray<d16_t, 1> d16x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
......@@ -448,12 +448,12 @@ struct vector_type<T, 32>
union
{
d32_t d32_;
StaticallyIndexedArray_v2<d1_t, 32> d1x32_;
StaticallyIndexedArray_v2<d2_t, 16> d2x16_;
StaticallyIndexedArray_v2<d4_t, 8> d4x8_;
StaticallyIndexedArray_v2<d8_t, 4> d8x4_;
StaticallyIndexedArray_v2<d16_t, 2> d16x2_;
StaticallyIndexedArray_v2<d32_t, 1> d32x1_;
StaticallyIndexedArray<d1_t, 32> d1x32_;
StaticallyIndexedArray<d2_t, 16> d2x16_;
StaticallyIndexedArray<d4_t, 8> d4x8_;
StaticallyIndexedArray<d8_t, 4> d8x4_;
StaticallyIndexedArray<d16_t, 2> d16x2_;
StaticallyIndexedArray<d32_t, 1> d32x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
......@@ -545,13 +545,13 @@ struct vector_type<T, 64>
union
{
d64_t d64_;
StaticallyIndexedArray_v2<d1_t, 64> d1x64_;
StaticallyIndexedArray_v2<d2_t, 32> d2x32_;
StaticallyIndexedArray_v2<d4_t, 16> d4x16_;
StaticallyIndexedArray_v2<d8_t, 8> d8x8_;
StaticallyIndexedArray_v2<d16_t, 4> d16x4_;
StaticallyIndexedArray_v2<d32_t, 2> d32x2_;
StaticallyIndexedArray_v2<d64_t, 1> d64x1_;
StaticallyIndexedArray<d1_t, 64> d1x64_;
StaticallyIndexedArray<d2_t, 32> d2x32_;
StaticallyIndexedArray<d4_t, 16> d4x16_;
StaticallyIndexedArray<d8_t, 8> d8x8_;
StaticallyIndexedArray<d16_t, 4> d16x4_;
StaticallyIndexedArray<d32_t, 2> d32x2_;
StaticallyIndexedArray<d64_t, 1> d64x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
......@@ -654,14 +654,14 @@ struct vector_type<T, 128>
union
{
d128_t d128_;
StaticallyIndexedArray_v2<d1_t, 128> d1x128_;
StaticallyIndexedArray_v2<d2_t, 64> d2x64_;
StaticallyIndexedArray_v2<d4_t, 32> d4x32_;
StaticallyIndexedArray_v2<d8_t, 16> d8x16_;
StaticallyIndexedArray_v2<d16_t, 8> d16x8_;
StaticallyIndexedArray_v2<d32_t, 4> d32x4_;
StaticallyIndexedArray_v2<d64_t, 2> d64x2_;
StaticallyIndexedArray_v2<d128_t, 1> d128x1_;
StaticallyIndexedArray<d1_t, 128> d1x128_;
StaticallyIndexedArray<d2_t, 64> d2x64_;
StaticallyIndexedArray<d4_t, 32> d4x32_;
StaticallyIndexedArray<d8_t, 16> d8x16_;
StaticallyIndexedArray<d16_t, 8> d16x8_;
StaticallyIndexedArray<d32_t, 4> d32x4_;
StaticallyIndexedArray<d64_t, 2> d64x2_;
StaticallyIndexedArray<d128_t, 1> d128x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
......@@ -773,15 +773,15 @@ struct vector_type<T, 256>
union
{
d256_t d256_;
StaticallyIndexedArray_v2<d1_t, 256> d1x256_;
StaticallyIndexedArray_v2<d2_t, 128> d2x128_;
StaticallyIndexedArray_v2<d4_t, 64> d4x64_;
StaticallyIndexedArray_v2<d8_t, 32> d8x32_;
StaticallyIndexedArray_v2<d16_t, 16> d16x16_;
StaticallyIndexedArray_v2<d32_t, 8> d32x8_;
StaticallyIndexedArray_v2<d64_t, 4> d64x4_;
StaticallyIndexedArray_v2<d128_t, 2> d128x2_;
StaticallyIndexedArray_v2<d256_t, 1> d256x1_;
StaticallyIndexedArray<d1_t, 256> d1x256_;
StaticallyIndexedArray<d2_t, 128> d2x128_;
StaticallyIndexedArray<d4_t, 64> d4x64_;
StaticallyIndexedArray<d8_t, 32> d8x32_;
StaticallyIndexedArray<d16_t, 16> d16x16_;
StaticallyIndexedArray<d32_t, 8> d32x8_;
StaticallyIndexedArray<d64_t, 4> d64x4_;
StaticallyIndexedArray<d128_t, 2> d128x2_;
StaticallyIndexedArray<d256_t, 1> d256x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
......
......@@ -9,7 +9,6 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "gemm_common.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_base.hpp"
......@@ -139,12 +138,12 @@ int main(int argc, char* argv[])
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5});
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5});
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment