Unverified commit bef7e52e authored by Lei Wang, committed by GitHub

[Compatibility] Support CUDA 11.3 (#1290)

parent 9e67b861
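
This change keeps the CUDA templates building on older toolkits (down to CUDA 11.3) by gating newer features at compile time: the cuda::atomic_ref paths are guarded by CUDART_VERSION >= 11080 and fall back to a new TL_NOT_IMPLEMENTED() trap, and the fp8 paths are guarded by __CUDA_ARCH_LIST__ >= 890. A minimal sketch of the toolkit gate, assuming only standard CUDA headers (the TL_HAS_ATOMIC_REF name is hypothetical, not from this diff):

// CUDART_VERSION is defined by cuda_runtime_api.h and encodes the toolkit
// version as major * 1000 + minor * 10, e.g. 11030 for CUDA 11.3.
#include <cuda_runtime_api.h>

#if CUDART_VERSION >= 11080
#include <cuda/atomic>      // libcu++ header that provides cuda::atomic_ref
#define TL_HAS_ATOMIC_REF 1 // hypothetical helper, not part of this commit
#else
#define TL_HAS_ATOMIC_REF 0
#endif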
@@ -12,7 +12,11 @@ using cutlass::bfloat16_t;
 using cutlass::half_t;
 
 #define TL_DEVICE __forceinline__ __device__
+#define TL_NOT_IMPLEMENTED()                                                   \
+  {                                                                            \
+    printf("%s not implemented\n", __PRETTY_FUNCTION__);                       \
+    asm volatile("brkpt;\n");                                                  \
+  }
 
 template <typename T> struct normalize_atomic_type {
   using type = T;
 };
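
TL_NOT_IMPLEMENTED() makes the unsupported paths fail loudly: __PRETTY_FUNCTION__ prints the full signature of the enclosing function, and the PTX brkpt instruction then raises a device-side breakpoint trap. A sketch of the resulting failure mode on an old toolkit (the kernel and the integer memory-order encoding are illustrative assumptions, not from this diff):

// Hypothetical kernel built with CUDA 11.3: AtomicLoad's #else branch expands
// to TL_NOT_IMPLEMENTED(), so the first thread to call it prints the function
// signature and stops at a brkpt trap rather than silently misbehaving.
__global__ void read_counter(int *counter, int *out) {
  *out = AtomicLoad(*counter, /*memory_order=*/0); // 0 assumed to mean relaxed
}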
@@ -63,8 +67,12 @@ TL_DEVICE void AtomicMax(T1 &ref, T2 val,
       }
     }
   } else {
+#if CUDART_VERSION >= 11080
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
     aref.fetch_max(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
+#else
+    TL_NOT_IMPLEMENTED();
+#endif
   }
 }
@@ -89,9 +97,13 @@ TL_DEVICE T1 AtomicMaxRet(T1 &ref, T2 val,
     }
     return static_cast<T1>(*reinterpret_cast<T1 *>(&old_val_ushort));
   } else {
+#if CUDART_VERSION >= 11080
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
     return static_cast<T1>(
         aref.fetch_max(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
+#else
+    TL_NOT_IMPLEMENTED();
+#endif
   }
 }
@@ -117,8 +129,13 @@ TL_DEVICE void AtomicMin(T1 &ref, T2 val,
       }
     }
   } else {
+#if CUDART_VERSION >= 11080
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
-    aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
+    return static_cast<T1>(
+        aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
+#else
+    TL_NOT_IMPLEMENTED();
+#endif
   }
 }
@@ -143,9 +160,13 @@ TL_DEVICE T1 AtomicMinRet(T1 &ref, T2 val,
     }
     return static_cast<T1>(*reinterpret_cast<T1 *>(&old_val_ushort));
   } else {
+#if CUDART_VERSION >= 11080
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
     return static_cast<T1>(
         aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
+#else
+    TL_NOT_IMPLEMENTED();
+#endif
   }
 }
@@ -216,8 +237,12 @@ TL_DEVICE void AtomicAdd(T1 &ref, T2 val,
       }
     }
   } else {
+#if CUDART_VERSION >= 11080
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
     aref.fetch_add(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
+#else
+    TL_NOT_IMPLEMENTED();
+#endif
   }
 }
@@ -290,9 +315,13 @@ TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val,
       }
     }
   } else {
+#if CUDART_VERSION >= 11080
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
     return static_cast<T1>(
         aref.fetch_add(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
+#else
+    TL_NOT_IMPLEMENTED();
+#endif
   }
 }
@@ -618,13 +647,21 @@ AtomicAddx4Ret(float *ref, float *val,
 #endif
 
 template <typename T> TL_DEVICE T AtomicLoad(T &ref, int memory_order) {
+#if CUDART_VERSION >= 11080
   cuda::atomic_ref<T, cuda::thread_scope_device> aref(ref);
   return aref.load(cuda::memory_order(memory_order));
+#else
+  TL_NOT_IMPLEMENTED();
+#endif
 }
 
 template <typename T1, typename T2>
 TL_DEVICE void AtomicStore(T1 &ref, T2 value, int memory_order) {
   using NT1 = typename normalize_atomic_type<T1>::type;
+#if CUDART_VERSION >= 11080
   cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(ref);
   aref.store(cuda_cast<NT1>(value), cuda::memory_order(memory_order));
+#else
+  TL_NOT_IMPLEMENTED();
+#endif
 }
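
Below CUDA 11.8 the commit traps instead of emulating cuda::atomic_ref. For 32-bit types an emulation via an atomicCAS loop would be possible; the following sketch of a float fetch-max fallback is illustrative only (not part of this commit, NaN handling omitted):

// Emulates an atomic fetch-max on float using the classic atomicCAS loop.
__device__ float fetch_max_cas(float *address, float val) {
  int *addr_as_int = reinterpret_cast<int *>(address);
  int old = *addr_as_int;
  int assumed;
  do {
    assumed = old;
    if (__int_as_float(assumed) >= val)
      break; // current value already large enough; nothing to store
    old = atomicCAS(addr_as_int, assumed, __float_as_int(val));
  } while (old != assumed); // retry if another thread intervened
  return __int_as_float(old);
}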
 #pragma once
 
+#if __CUDA_ARCH_LIST__ >= 890
 #include "./cuda_fp8.h"
+#endif
 #include "common.h"
 
 #ifndef __CUDACC_RTC__
@@ -117,6 +120,7 @@ __device__ void debug_print_var<double>(const char *msg, double var) {
                threadIdx.z, var);
 }
 
+#if __CUDA_ARCH_LIST__ >= 890
 // Specialization for fp8_e4_t type
 template <>
 __device__ void debug_print_var<fp8_e4_t>(const char *msg, fp8_e4_t var) {
@@ -137,6 +141,8 @@ __device__ void debug_print_var<fp8_e5_t>(const char *msg, fp8_e5_t var) {
                threadIdx.z, (float)var);
 }
+#endif
 // Template declaration for device-side debug printing (buffer only)
 template <typename T>
 __device__ void debug_print_buffer_value(const char *msg, const char *buf_name,
@@ -242,6 +248,7 @@ __device__ void debug_print_buffer_value<double>(const char *msg,
 }
 
 // Specialization for fp8_e4_t type
+#if __CUDA_ARCH_LIST__ >= 890
 template <>
 __device__ void debug_print_buffer_value<fp8_e4_t>(const char *msg,
                                                    const char *buf_name,
@@ -263,6 +270,8 @@ __device__ void debug_print_buffer_value<fp8_e5_t>(const char *msg,
                threadIdx.z, buf_name, index, (float)var);
 }
+#endif
 // Specialization for int16 type
 template <>
 __device__ void debug_print_buffer_value<int16_t>(const char *msg,
...
@@ -8,7 +8,6 @@
 #include <cute/underscore.hpp>
 
 #include "common.h"
-#include "cuda_fp8.h"
 #include "intrin.h"
 
 namespace cute::tl_mma {
...
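
The __CUDA_ARCH_LIST__ >= 890 guard keeps the fp8 specializations, which depend on this repo's cuda_fp8.h and on SM 8.9 (Ada) or newer, out of builds for older targets. On toolkits where nvcc does not define __CUDA_ARCH_LIST__, the identifier evaluates to 0 in the #if and fp8 support is simply compiled out, which is what CUDA 11.3 needs. A sketch of the same guard at a call site (the kernel is hypothetical; fp8_e4_t is assumed to convert from float, as the debug printers above suggest):

#if __CUDA_ARCH_LIST__ >= 890
// Compiled only when the requested architecture list passes the SM 8.9 gate.
__global__ void fill_fp8(fp8_e4_t *dst, float v, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    dst[i] = fp8_e4_t(v); // float -> fp8 e4m3 conversion (assumed constructor)
}
#endif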