OpenDAS / tilelang · Commit bef7e52e (unverified)

[Compatibility] Support CUDA 11.3 (#1290)

Authored Nov 20, 2025 by Lei Wang; committed via GitHub on Nov 20, 2025.
Parent: 9e67b861 · Changes: 3 files
Showing 3 changed files with 48 additions and 3 deletions:

  src/tl_templates/cuda/atomic.h     +39  -2
  src/tl_templates/cuda/debug.h       +9  -0
  src/tl_templates/cuda/gemm_mma.h    +0  -1
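All three files follow the same compatibility pattern: code that needs features newer than CUDA 11.3 is fenced behind a compile-time toolkit or architecture check, with a trap macro as the fallback so the headers still build on the older toolchain. Below is a minimal standalone sketch of that pattern, assuming only the CUDA runtime headers; the helper name my_atomic_max_u32 is illustrative and not part of tilelang.

#include <cstdio>
#include <cuda_runtime.h> // defines CUDART_VERSION
#if CUDART_VERSION >= 11080
#include <cuda/atomic> // libcu++ cuda::atomic_ref; not available on CUDA 11.3
#endif

#define TL_DEVICE __forceinline__ __device__
// Trap taken when the requested operation has no implementation on this toolkit.
#define TL_NOT_IMPLEMENTED()                                                   \
  {                                                                            \
    printf("%s not implemented\n", __PRETTY_FUNCTION__);                       \
    asm volatile("brkpt;\n");                                                  \
  }

// Illustrative device-wide atomic max, guarded the same way the commit guards
// AtomicMax / AtomicMin / AtomicAdd and their *Ret variants.
TL_DEVICE void my_atomic_max_u32(unsigned int &ref, unsigned int val) {
#if CUDART_VERSION >= 11080
  cuda::atomic_ref<unsigned int, cuda::thread_scope_device> aref(ref);
  aref.fetch_max(val); // the real helpers also forward a memory order here
#else
  TL_NOT_IMPLEMENTED(); // CUDA 11.3: no cuda::atomic_ref, trap instead
#endif
}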
src/tl_templates/cuda/atomic.h
@@ -12,7 +12,11 @@ using cutlass::bfloat16_t;
 using cutlass::half_t;
 
 #define TL_DEVICE __forceinline__ __device__
+#define TL_NOT_IMPLEMENTED()                                                   \
+  {                                                                            \
+    printf("%s not implemented\n", __PRETTY_FUNCTION__);                       \
+    asm volatile("brkpt;\n");                                                  \
+  }
 
 template <typename T> struct normalize_atomic_type {
   using type = T;
 };
@@ -63,8 +67,12 @@ TL_DEVICE void AtomicMax(T1 &ref, T2 val,
       }
     }
   } else {
+#if CUDART_VERSION >= 11080
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
     aref.fetch_max(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
+#else
+    TL_NOT_IMPLEMENTED();
+#endif
   }
 }
@@ -89,9 +97,13 @@ TL_DEVICE T1 AtomicMaxRet(T1 &ref, T2 val,
     }
     return static_cast<T1>(*reinterpret_cast<T1 *>(&old_val_ushort));
   } else {
+#if CUDART_VERSION >= 11080
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
     return static_cast<T1>(
         aref.fetch_max(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
+#else
+    TL_NOT_IMPLEMENTED();
+#endif
   }
 }
@@ -117,8 +129,13 @@ TL_DEVICE void AtomicMin(T1 &ref, T2 val,
       }
     }
   } else {
+#if CUDART_VERSION >= 11080
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
-    aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
+    return static_cast<T1>(
+        aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
+#else
+    TL_NOT_IMPLEMENTED();
+#endif
   }
 }
@@ -143,9 +160,13 @@ TL_DEVICE T1 AtomicMinRet(T1 &ref, T2 val,
     }
     return static_cast<T1>(*reinterpret_cast<T1 *>(&old_val_ushort));
   } else {
+#if CUDART_VERSION >= 11080
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
     return static_cast<T1>(
         aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
+#else
+    TL_NOT_IMPLEMENTED();
+#endif
   }
 }
@@ -216,8 +237,12 @@ TL_DEVICE void AtomicAdd(T1 &ref, T2 val,
       }
     }
   } else {
+#if CUDART_VERSION >= 11080
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
     aref.fetch_add(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
+#else
+    TL_NOT_IMPLEMENTED();
+#endif
   }
 }
@@ -290,9 +315,13 @@ TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val,
       }
     }
   } else {
+#if CUDART_VERSION >= 11080
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
     return static_cast<T1>(
         aref.fetch_add(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
+#else
+    TL_NOT_IMPLEMENTED();
+#endif
   }
 }
@@ -618,13 +647,21 @@ AtomicAddx4Ret(float *ref, float *val,
 #endif
 
 template <typename T> TL_DEVICE T AtomicLoad(T &ref, int memory_order) {
+#if CUDART_VERSION >= 11080
   cuda::atomic_ref<T, cuda::thread_scope_device> aref(ref);
   return aref.load(cuda::memory_order(memory_order));
+#else
+  TL_NOT_IMPLEMENTED();
+#endif
 }
 
 template <typename T1, typename T2>
 TL_DEVICE void AtomicStore(T1 &ref, T2 value, int memory_order) {
   using NT1 = typename normalize_atomic_type<T1>::type;
+#if CUDART_VERSION >= 11080
   cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(ref);
   aref.store(cuda_cast<NT1>(value), cuda::memory_order(memory_order));
+#else
+  TL_NOT_IMPLEMENTED();
+#endif
 }
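For context, the branch behind CUDART_VERSION >= 11080 relies on libcu++'s cuda::atomic_ref, which ships with newer CUDA toolkits and provides device-wide atomics with an explicit memory order. A self-contained sketch of what that branch boils down to for a plain int follows; the kernel name and counter are illustrative only, not tilelang code. On a CUDA 11.3 build the same call site would instead hit TL_NOT_IMPLEMENTED() and trap at runtime.

#include <cuda/atomic>

// Illustrative only: roughly what the guarded AtomicAdd path does for int.
__global__ void add_one(int *counter) {
  cuda::atomic_ref<int, cuda::thread_scope_device> aref(*counter);
  aref.fetch_add(1); // seq_cst by default; tilelang forwards the caller's order
}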
src/tl_templates/cuda/debug.h
 #pragma once
 
+#if __CUDA_ARCH_LIST__ >= 890
 #include "./cuda_fp8.h"
+#endif
 #include "common.h"
 
 #ifndef __CUDACC_RTC__
@@ -117,6 +120,7 @@ __device__ void debug_print_var<double>(const char *msg, double var) {
          threadIdx.z, var);
 }
 
+#if __CUDA_ARCH_LIST__ >= 890
 // Specialization for fp8_e4_t type
 template <>
 __device__ void debug_print_var<fp8_e4_t>(const char *msg, fp8_e4_t var) {
@@ -137,6 +141,8 @@ __device__ void debug_print_var<fp8_e5_t>(const char *msg, fp8_e5_t var) {
          threadIdx.z, (float)var);
 }
+#endif
 
 // Template declaration for device-side debug printing (buffer only)
 template <typename T>
 __device__ void debug_print_buffer_value(const char *msg, const char *buf_name,
@@ -242,6 +248,7 @@ __device__ void debug_print_buffer_value<double>(const char *msg,
 }
 
 // Specialization for fp8_e4_t type
+#if __CUDA_ARCH_LIST__ >= 890
 template <>
 __device__ void debug_print_buffer_value<fp8_e4_t>(const char *msg,
                                                    const char *buf_name,
@@ -263,6 +270,8 @@ __device__ void debug_print_buffer_value<fp8_e5_t>(const char *msg,
          threadIdx.z, buf_name, index, (float)var);
 }
+#endif
 
 // Specialization for int16 type
 template <>
 __device__ void debug_print_buffer_value<int16_t>(const char *msg,
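debug.h takes the analogous route for FP8: the ./cuda_fp8.h include and the fp8_e4_t / fp8_e5_t print specializations are compiled only when __CUDA_ARCH_LIST__ >= 890. On toolchains that do not define that macro, the preprocessor evaluates the unknown identifier as 0, so the FP8 code is skipped there as well. A minimal sketch of the same guard outside tilelang, assuming the CUDA <cuda_fp8.h> header and its __nv_fp8_e4m3 type; the helper name print_val is illustrative:

#include <cstdio>
#include <cstring>
#if __CUDA_ARCH_LIST__ >= 890
#include <cuda_fp8.h> // FP8 types only exist on newer toolkits / sm_89+ targets
#endif

// Generic fallback: prints a placeholder for types without a specialization.
template <typename T> __device__ void print_val(const char *msg, T) {
  printf("%s: <unsupported type>\n", msg);
}

#if __CUDA_ARCH_LIST__ >= 890
// FP8 specialization, compiled only when sm_89 or newer is in the arch list.
template <>
__device__ void print_val<__nv_fp8_e4m3>(const char *msg, __nv_fp8_e4m3 v) {
  unsigned char raw;
  memcpy(&raw, &v, 1); // fp8 is one byte; print the raw encoding
  printf("%s: fp8 raw=0x%02x\n", msg, (unsigned)raw);
}
#endif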
src/tl_templates/cuda/gemm_mma.h
@@ -8,7 +8,6 @@
 #include <cute/underscore.hpp>
 
 #include "common.h"
-#include "cuda_fp8.h"
 #include "intrin.h"
 
 namespace cute::tl_mma {