Commit 049cc8af authored by aska-0096
Browse files

change arch limitation

parent 7dca8463
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#define CK_AMD_WMMA_HPP #define CK_AMD_WMMA_HPP
#include "data_type.hpp" #include "data_type.hpp"
// TODO: Add arch limitation
namespace ck { namespace ck {
// wave32 only // wave32 only
......
...@@ -52,4 +52,6 @@ add_subdirectory(block_to_ctile_map) ...@@ -52,4 +52,6 @@ add_subdirectory(block_to_ctile_map)
add_subdirectory(softmax) add_subdirectory(softmax)
add_subdirectory(normalization) add_subdirectory(normalization)
add_subdirectory(data_type) add_subdirectory(data_type)
add_subdirectory(wmma_op) if(GPU_TARGETS MATCHES "gfx1100")
add_subdirectory(wmma_op)
endif()
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
namespace ck { namespace ck {
__global__ void matmul(const half_t* a, const half_t* b, float* c) __global__ void matmul(const half_t* a, const half_t* b, float* c)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
const int lIdx = threadIdx.x; const int lIdx = threadIdx.x;
// a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and // a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and
...@@ -53,16 +52,10 @@ __global__ void matmul(const half_t* a, const half_t* b, float* c) ...@@ -53,16 +52,10 @@ __global__ void matmul(const half_t* a, const half_t* b, float* c)
// store results from unpacked c_thread_buf_ output // store results from unpacked c_thread_buf_ output
c[16 * r + lane] = c_thread_buf_[Number<ele>{}]; c[16 * r + lane] = c_thread_buf_[Number<ele>{}];
}); });
#else
ignore = a;
ignore = b;
ignore = c;
#endif // end of if (defined(__gfx1100__))
} }
__global__ void matmul_swizzle_a(const half_t* a, const half_t* b, float* c) __global__ void matmul_swizzle_a(const half_t* a, const half_t* b, float* c)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
const int lIdx = threadIdx.x; const int lIdx = threadIdx.x;
half16_t a_frag = {}; half16_t a_frag = {};
...@@ -92,11 +85,6 @@ __global__ void matmul_swizzle_a(const half_t* a, const half_t* b, float* c) ...@@ -92,11 +85,6 @@ __global__ void matmul_swizzle_a(const half_t* a, const half_t* b, float* c)
const int r = ele; const int r = ele;
c[16 * 8 * blk + 16 * r + lane] = c_thread_buf_[Number<ele>{}]; c[16 * 8 * blk + 16 * r + lane] = c_thread_buf_[Number<ele>{}];
}); });
#else
ignore = a;
ignore = b;
ignore = c;
#endif // end of if (defined(__gfx1100__))
} }
} // namespace ck } // namespace ck
...@@ -173,11 +161,9 @@ int main(int, char*[]) ...@@ -173,11 +161,9 @@ int main(int, char*[])
// result check // result check
bool res = true; bool res = true;
bool res_swizzle_a = true; bool res_swizzle_a = true;
#if(defined(__gfx1100__))
res = ck::utils::check_err(wmma_c, host_c, "Error: Incorrect results!", 1e-2); res = ck::utils::check_err(wmma_c, host_c, "Error: Incorrect results!", 1e-2);
res_swizzle_a = res_swizzle_a =
ck::utils::check_err(wmma_c_swizzle_a, host_c, "Error: Incorrect results!", 1e-2); ck::utils::check_err(wmma_c_swizzle_a, host_c, "Error: Incorrect results!", 1e-2);
#endif // end of if (defined(__gfx1100__))
if(res && res_swizzle_a) if(res && res_swizzle_a)
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment