Commit 54241df6 authored by fengzch
Browse files

fix: compile gemm_w8a8.cu complete

parent 2cb9a2c7
...@@ -21,14 +21,14 @@ __device__ __forceinline__ static T load(const T *addr) { ...@@ -21,14 +21,14 @@ __device__ __forceinline__ static T load(const T *addr) {
uint2 data; uint2 data;
asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];" asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];"
: "=r"(data.x), "=r"(data.y) : "=r"(data.x), "=r"(data.y)
: "l"(__cvta_generic_to_shared(addr))); : "l"((addr)));
return *reinterpret_cast<T *>(&data); return *reinterpret_cast<T *>(&data);
} }
if constexpr (sizeof(T) == 16) { if constexpr (sizeof(T) == 16) {
uint4 data; uint4 data;
asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];" asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];"
: "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
: "l"(__cvta_generic_to_shared(addr))); : "l"((addr)));
return *reinterpret_cast<T *>(&data); return *reinterpret_cast<T *>(&data);
} }
return *addr; return *addr;
...@@ -89,12 +89,12 @@ __device__ __forceinline__ static void store(T *addr, T val) { ...@@ -89,12 +89,12 @@ __device__ __forceinline__ static void store(T *addr, T val) {
if constexpr (sizeof(T) == 8) { if constexpr (sizeof(T) == 8) {
uint2 data = *reinterpret_cast<uint2 *>(&val); uint2 data = *reinterpret_cast<uint2 *>(&val);
asm volatile( asm volatile(
"st.shared.v2.b32 [%0], {%1, %2};" ::"l"(__cvta_generic_to_shared(addr)), "r"(data.x), "r"(data.y)); "st.shared.v2.b32 [%0], {%1, %2};" ::"l"((addr)), "r"(data.x), "r"(data.y));
return; return;
} }
if constexpr (sizeof(T) == 16) { if constexpr (sizeof(T) == 16) {
uint4 data = *reinterpret_cast<uint4 *>(&val); uint4 data = *reinterpret_cast<uint4 *>(&val);
asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"(__cvta_generic_to_shared(addr)), asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"((addr)),
"r"(data.x), "r"(data.x),
"r"(data.y), "r"(data.y),
"r"(data.z), "r"(data.z),
...@@ -192,9 +192,9 @@ __device__ __forceinline__ static void unused_var(T &val, bool alwaysfalse) { ...@@ -192,9 +192,9 @@ __device__ __forceinline__ static void unused_var(T &val, bool alwaysfalse) {
} }
__device__ __forceinline__ static void ldmatrix(const void *ptr, uint4 &out) { __device__ __forceinline__ static void ldmatrix(const void *ptr, uint4 &out) {
// asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
// : "=r"(out.x), "=r"(out.y), "=r"(out.z), "=r"(out.w) : "=r"(out.x), "=r"(out.y), "=r"(out.z), "=r"(out.w)
// : "l"(__cvta_generic_to_shared(ptr))); // limengmeng : "l"((ptr))); // limengmeng
} }
template<typename T> template<typename T>
......
...@@ -26,7 +26,7 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl ...@@ -26,7 +26,7 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
auto func = auto func =
invoke_kernel<kernel, const GEMM::half_t *, GEMM::packed_act_t *, GEMM::packed_ascale_t *, int, bool>; invoke_kernel<kernel, const GEMM::half_t *, GEMM::packed_act_t *, GEMM::packed_ascale_t *, int, bool>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, 92160)); checkCUDA(cudaFuncSetAttribute(reinterpret_cast<const void*>(func), cudaFuncAttributeMaxDynamicSharedMemorySize, 92160));
func<<<grid, block, kernel::smemSize(M, K)>>>(input.data_ptr<GEMM::half_t>(), func<<<grid, block, kernel::smemSize(M, K)>>>(input.data_ptr<GEMM::half_t>(),
output.data_ptr<GEMM::packed_act_t>(), output.data_ptr<GEMM::packed_act_t>(),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment