Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
469e903b
Commit
469e903b
authored
Mar 28, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.2' into v0.8.2-dev
parents
389ebcf7
25f560a6
Changes
535
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
896 additions
and
198 deletions
+896
-198
csrc/core/math.hpp
csrc/core/math.hpp
+0
-5
csrc/cpu/attention.cpp
csrc/cpu/attention.cpp
+2
-2
csrc/cpu/cache.cpp
csrc/cpu/cache.cpp
+21
-17
csrc/cpu/cpu_types.hpp
csrc/cpu/cpu_types.hpp
+3
-0
csrc/cpu/cpu_types_arm.hpp
csrc/cpu/cpu_types_arm.hpp
+4
-0
csrc/cpu/cpu_types_vxe.hpp
csrc/cpu/cpu_types_vxe.hpp
+480
-0
csrc/cpu/cpu_types_x86.hpp
csrc/cpu/cpu_types_x86.hpp
+9
-0
csrc/cpu/pos_encoding.cpp
csrc/cpu/pos_encoding.cpp
+1
-1
csrc/cpu/quant.cpp
csrc/cpu/quant.cpp
+1
-1
csrc/cuda_utils.h
csrc/cuda_utils.h
+18
-4
csrc/custom_all_reduce.cu
csrc/custom_all_reduce.cu
+41
-0
csrc/custom_all_reduce.cuh
csrc/custom_all_reduce.cuh
+140
-50
csrc/custom_all_reduce_test.cu
csrc/custom_all_reduce_test.cu
+46
-29
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
+14
-12
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
+38
-38
csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
...mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
+1
-1
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
+7
-7
csrc/dispatch_utils.h
csrc/dispatch_utils.h
+26
-6
csrc/layernorm_quant_kernels.cu
csrc/layernorm_quant_kernels.cu
+33
-25
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+11
-0
No files found.
Too many changes to show.
To preserve performance only
535 of 535+
files are displayed.
Plain diff
Email patch
csrc/core/math.hpp
View file @
469e903b
...
@@ -7,8 +7,3 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) {
...
@@ -7,8 +7,3 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) {
if
(
num
<=
1
)
return
num
;
if
(
num
<=
1
)
return
num
;
return
1
<<
(
CHAR_BIT
*
sizeof
(
num
)
-
__builtin_clz
(
num
-
1
));
return
1
<<
(
CHAR_BIT
*
sizeof
(
num
)
-
__builtin_clz
(
num
-
1
));
}
}
template
<
typename
T
>
inline
constexpr
std
::
enable_if_t
<
std
::
is_integral_v
<
T
>
,
T
>
ceil_div
(
T
a
,
T
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
\ No newline at end of file
csrc/cpu/attention.cpp
View file @
469e903b
...
@@ -24,8 +24,8 @@ struct KernelVecType<float> {
...
@@ -24,8 +24,8 @@ struct KernelVecType<float> {
template
<
>
template
<
>
struct
KernelVecType
<
c10
::
Half
>
{
struct
KernelVecType
<
c10
::
Half
>
{
#ifdef
__powerpc64__
#if
def
ined(
__powerpc64__
) || defined(__s390x__)
// Power architecture-specific vector types
// Power
and s390x
architecture-specific vector types
using
q_load_vec_type
=
vec_op
::
FP32Vec8
;
using
q_load_vec_type
=
vec_op
::
FP32Vec8
;
using
k_load_vec_type
=
vec_op
::
FP32Vec16
;
using
k_load_vec_type
=
vec_op
::
FP32Vec16
;
using
v_load_vec_type
=
vec_op
::
FP32Vec16
;
using
v_load_vec_type
=
vec_op
::
FP32Vec16
;
...
...
csrc/cpu/cache.cpp
View file @
469e903b
...
@@ -3,6 +3,12 @@
...
@@ -3,6 +3,12 @@
#include "cpu_types.hpp"
#include "cpu_types.hpp"
#if defined(__x86_64__)
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2
#else
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES
#endif
namespace
{
namespace
{
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
void
copy_blocks_cpu_impl
(
std
::
vector
<
torch
::
Tensor
>
const
&
key_caches
,
void
copy_blocks_cpu_impl
(
std
::
vector
<
torch
::
Tensor
>
const
&
key_caches
,
...
@@ -95,13 +101,12 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
...
@@ -95,13 +101,12 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
}
}
const
int
element_num_per_block
=
key_caches
[
0
][
0
].
numel
();
const
int
element_num_per_block
=
key_caches
[
0
][
0
].
numel
();
VLLM_DISPATCH_FLOATING_TYPES
(
DISPATCH_MACRO
(
key_caches
[
0
].
scalar_type
(),
"copy_blocks_cpu_impl"
,
[
&
]
{
key_caches
[
0
].
scalar_type
(),
"copy_blocks_cpu_impl"
,
[
&
]
{
CPU_KERNEL_GUARD_IN
(
copy_blocks_cpu_impl
)
CPU_KERNEL_GUARD_IN
(
copy_blocks_cpu_impl
)
copy_blocks_cpu_impl
<
scalar_t
>
(
key_caches
,
value_caches
,
block_mapping
,
copy_blocks_cpu_impl
<
scalar_t
>
(
key_caches
,
value_caches
,
block_mapping
,
element_num_per_block
,
num_layers
);
element_num_per_block
,
num_layers
);
CPU_KERNEL_GUARD_OUT
(
copy_blocks_cpu_impl
)
CPU_KERNEL_GUARD_OUT
(
copy_blocks_cpu_impl
)
});
});
}
}
void
reshape_and_cache
(
torch
::
Tensor
&
key
,
torch
::
Tensor
&
value
,
void
reshape_and_cache
(
torch
::
Tensor
&
key
,
torch
::
Tensor
&
value
,
...
@@ -118,16 +123,15 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
...
@@ -118,16 +123,15 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
int
key_stride
=
key
.
stride
(
0
);
int
key_stride
=
key
.
stride
(
0
);
int
value_stride
=
value
.
stride
(
0
);
int
value_stride
=
value
.
stride
(
0
);
VLLM_DISPATCH_FLOATING_TYPES
(
DISPATCH_MACRO
(
key
.
scalar_type
(),
"reshape_and_cache_cpu_impl"
,
[
&
]
{
key
.
scalar_type
(),
"reshape_and_cache_cpu_impl"
,
[
&
]
{
CPU_KERNEL_GUARD_IN
(
reshape_and_cache_cpu_impl
)
CPU_KERNEL_GUARD_IN
(
reshape_and_cache_cpu_impl
)
reshape_and_cache_cpu_impl
<
scalar_t
>
(
reshape_and_cache_cpu_impl
<
scalar_t
>
(
key
.
data_ptr
<
scalar_t
>
(),
value
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
value
.
data_ptr
<
scalar_t
>
(),
key_cache
.
data_ptr
<
scalar_t
>
(),
value_cache
.
data_ptr
<
scalar_t
>
(),
key_cache
.
data_ptr
<
scalar_t
>
(),
value_cache
.
data_ptr
<
scalar_t
>
(),
slot_mapping
.
data_ptr
<
int64_t
>
(),
num_tokens
,
key_stride
,
value_stride
,
slot_mapping
.
data_ptr
<
int64_t
>
(),
num_tokens
,
key_stride
,
num_heads
,
head_size
,
block_size
,
x
);
value_stride
,
num_heads
,
head_size
,
block_size
,
x
);
CPU_KERNEL_GUARD_OUT
(
reshape_and_cache_cpu_impl
)
CPU_KERNEL_GUARD_OUT
(
reshape_and_cache_cpu_impl
)
});
});
}
}
void
swap_blocks
(
torch
::
Tensor
&
src
,
torch
::
Tensor
&
dst
,
void
swap_blocks
(
torch
::
Tensor
&
src
,
torch
::
Tensor
&
dst
,
...
...
csrc/cpu/cpu_types.hpp
View file @
469e903b
...
@@ -7,6 +7,9 @@
...
@@ -7,6 +7,9 @@
#elif defined(__POWER9_VECTOR__)
#elif defined(__POWER9_VECTOR__)
// ppc implementation
// ppc implementation
#include "cpu_types_vsx.hpp"
#include "cpu_types_vsx.hpp"
#elif defined(__s390x__)
// s390 implementation
#include "cpu_types_vxe.hpp"
#elif defined(__aarch64__)
#elif defined(__aarch64__)
// arm implementation
// arm implementation
#include "cpu_types_arm.hpp"
#include "cpu_types_arm.hpp"
...
...
csrc/cpu/cpu_types_arm.hpp
View file @
469e903b
...
@@ -2,6 +2,10 @@
...
@@ -2,6 +2,10 @@
#include <torch/all.h>
#include <torch/all.h>
#include <cmath>
#include <cmath>
#if defined(__APPLE__)
#include "omp.h"
#endif
namespace
vec_op
{
namespace
vec_op
{
#ifdef ARM_BF16_SUPPORT
#ifdef ARM_BF16_SUPPORT
...
...
csrc/cpu/cpu_types_vxe.hpp
0 → 100644
View file @
469e903b
#ifndef CPU_TYPES_VXE_HPP
#define CPU_TYPES_VXE_HPP
#include <vecintrin.h>
#include <cmath>
#include <torch/all.h>
namespace
vec_op
{
#define vec_neg(a) (-(a))
#define vec_add(a, b) ((a) + (b))
#define vec_sub(a, b) ((a) - (b))
#define vec_mul(a, b) ((a) * (b))
#define vec_div(a, b) ((a) / (b))
#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebraic
#define vec_sl(a, b) ((a) << (b)) // Vector Shift Left
// FIXME: FP16 is not fully supported in Torch-CPU
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) \
std::cout << #NAME << " exit." << std::endl;
#endif
#define FORCE_INLINE __attribute__((always_inline)) inline
namespace {
/// Expands the compile-time index pack `Is...` and invokes `body` once per
/// index, left to right, passing each index as a std::integral_constant.
template <typename T, T... Is, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, Is...>, F&& body) {
  (body(std::integral_constant<T, Is>{}), ...);
}
}  // namespace

/// Calls `f(i)` for every i in [0, count), with the loop expanded at compile
/// time instead of being emitted as a runtime loop.
///
/// @tparam T      integral type of the loop index
/// @tparam count  number of iterations
/// @tparam F      callable; must be invocable with a `T` argument
template <typename T, T count, typename F,
          typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F&& f) {
  using index_pack = std::make_integer_sequence<T, count>;
  unroll_loop_item(index_pack{}, std::forward<F>(f));
}
/// CRTP base shared by the vector wrapper types in this file.
///
/// The derived type `T` must expose a compile-time constant
/// `T::VEC_ELEM_NUM`; this base only forwards it.
template <typename T>
struct Vec {
  /// Returns the element count declared by the derived type.
  static constexpr int get_elem_num() { return T::VEC_ELEM_NUM; }
};
// Fixed-size bundles of vector registers. Each bundle is a plain aggregate
// whose `val` array holds 2 or 4 vectors back to back.

/// Two vectors of signed short.
struct ss16x8x2_t {
  __vector signed short val[2];
};

/// Four vectors of signed short.
struct ss16x8x4_t {
  __vector signed short val[4];
};

/// Two vectors of float.
struct f32x4x2_t {
  __vector float val[2];
};

/// Four vectors of float.
struct f32x4x4_t {
  __vector float val[4];
};
struct
FP32Vec8
;
struct
FP32Vec16
;
/// Eight 16-bit lanes held in a single `__vector signed short`.
/// The FP32 conversions in this file treat each lane as a 16-bit payload that
/// is widened into / narrowed from a 32-bit float word.
struct BF16Vec8 : public Vec<BF16Vec8> {
  constexpr static int VEC_ELEM_NUM = 8;

  __vector signed short reg;

  /// Loads one vector's worth of data by dereferencing `ptr` as a vector.
  explicit BF16Vec8(const void* ptr)
      : reg(*reinterpret_cast<const __vector signed short*>(ptr)) {}

  /// Narrowing conversion from eight floats; defined after FP32Vec8.
  explicit BF16Vec8(const FP32Vec8&);

  /// Stores the vector to `ptr`.
  void save(void* ptr) const {
    auto* dst = reinterpret_cast<__vector signed short*>(ptr);
    *dst = reg;
  }
};
/// Sixteen 16-bit lanes stored as two `__vector signed short` registers.
struct BF16Vec16 : public Vec<BF16Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;

  ss16x8x2_t reg;

  /// Loads 256 bits from `ptr` as two vector loads at byte offsets 0 and 16.
  explicit BF16Vec16(const void* ptr) {
    // vec_xl is handed a non-const element pointer, exactly as before.
    signed short* src =
        const_cast<signed short*>(static_cast<const signed short*>(ptr));
    reg.val[0] = (__vector signed short)vec_xl(0, src);
    reg.val[1] = (__vector signed short)vec_xl(16, src);
  }

  /// Narrowing conversion from sixteen floats; defined after FP32Vec16.
  explicit BF16Vec16(const FP32Vec16&);

  /// Stores 256 bits to `ptr` as two vector stores at byte offsets 0 and 16.
  void save(void* ptr) const {
    signed short* dst = static_cast<signed short*>(ptr);
    vec_xst(reg.val[0], 0, dst);
    vec_xst(reg.val[1], 16, dst);
  }
};
// All-zero vector of signed shorts; merged with 16-bit lanes when widening
// them to 32-bit words in the FP32Vec8 / FP32Vec16 constructors.
static const __vector signed short zero =
    vec_splats(static_cast<signed short>(0));
/// Thirty-two 16-bit lanes stored as four `__vector signed short` registers.
struct BF16Vec32 : public Vec<BF16Vec32> {
  constexpr static int VEC_ELEM_NUM = 32;

  ss16x8x4_t reg;

  /// Loads all four registers by copying an `ss16x8x4_t` out of `ptr`.
  explicit BF16Vec32(const void* ptr)
      : reg(*reinterpret_cast<const ss16x8x4_t*>(ptr)) {}

  /// Wraps an already-populated register bundle.
  explicit BF16Vec32(ss16x8x4_t data) : reg(data) {}

  /// Replicates the single 8-lane vector into all four registers.
  explicit BF16Vec32(const BF16Vec8& vec8_data) {
    for (auto& lane_group : reg.val) {
      lane_group = vec8_data.reg;
    }
  }

  /// Stores all four registers by copying the bundle into `ptr`.
  void save(void* ptr) const {
    auto* dst = reinterpret_cast<ss16x8x4_t*>(ptr);
    *dst = reg;
  }
};
/// Four floats held in a single `__vector float`.
struct FP32Vec4 : public Vec<FP32Vec4> {
  constexpr static int VEC_ELEM_NUM = 4;

  /// Overlay giving per-element access to the vector register.
  union AliasReg {
    __vector float reg;
    float values[VEC_ELEM_NUM];
  };

  __vector float reg;

  /// Broadcasts `v` to every lane.
  explicit FP32Vec4(float v) : reg(vec_splats(v)) {}

  /// All lanes zero.
  explicit FP32Vec4() : FP32Vec4(0.0f) {}

  /// Loads four floats starting at `ptr`.
  explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {}

  /// Wraps an existing vector register value.
  explicit FP32Vec4(__vector float data) : reg(data) {}

  /// Copies the register of another FP32Vec4.
  explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
};
/// Eight floats stored as two `__vector float` registers.
struct FP32Vec8 : public Vec<FP32Vec8> {
  constexpr static int VEC_ELEM_NUM = 8;

  /// Overlay giving per-element access to the two registers.
  union AliasReg {
    f32x4x2_t reg;
    float values[VEC_ELEM_NUM];
  };

  f32x4x2_t reg;

  /// Broadcasts `v` to every lane.
  explicit FP32Vec8(float v) {
    for (auto& half : reg.val) {
      half = vec_splats(v);
    }
  }

  /// All lanes zero.
  explicit FP32Vec8() : FP32Vec8(0.0f) {}

  /// Loads eight floats from `ptr` (vector loads at byte offsets 0 and 16).
  explicit FP32Vec8(const float* ptr) {
    reg.val[0] = vec_xl(0, ptr);
    reg.val[1] = vec_xl(16, ptr);
  }

  /// Wraps an already-populated register bundle.
  explicit FP32Vec8(f32x4x2_t data) : reg(data) {}

  /// Copies both registers of another FP32Vec8.
  explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {}

  /// Widens eight 16-bit lanes to 32-bit words by interleaving them with the
  /// file-level `zero` vector and reinterpreting the result as floats.
  // NOTE(review): confirm that vec_mergeh/vec_mergel(zero, x) places the
  // 16-bit payload in the upper half of each 32-bit word on this target; the
  // narrowing constructors in this file take the upper half (`>> 16`).
  explicit FP32Vec8(const BF16Vec8& v) {
    reg.val[0] = (__vector float)vec_mergeh(zero, v.reg);
    reg.val[1] = (__vector float)vec_mergel(zero, v.reg);
  }

  /// Sums all eight lanes in index order, starting from 0.
  float reduce_sum() const {
    AliasReg lanes;
    lanes.reg = reg;
    float total = 0;
    unroll_loop<int, VEC_ELEM_NUM>(
        [&total, &lanes](int i) { total += lanes.values[i]; });
    return total;
  }

  /// Element-wise std::exp.
  FP32Vec8 exp() const {
    // TODO: Vectorize this
    AliasReg in;
    in.reg = reg;
    AliasReg out;
    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
      out.values[i] = std::exp(in.values[i]);
    }
    return FP32Vec8(out.reg);
  }

  /// Element-wise std::tanh.
  FP32Vec8 tanh() const {
    // TODO: Vectorize this
    AliasReg in;
    in.reg = reg;
    AliasReg out;
    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
      out.values[i] = std::tanh(in.values[i]);
    }
    return FP32Vec8(out.reg);
  }

  /// Element-wise std::erf (the method is named `er`).
  FP32Vec8 er() const {
    // TODO: Vectorize this
    AliasReg in;
    in.reg = reg;
    AliasReg out;
    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
      out.values[i] = std::erf(in.values[i]);
    }
    return FP32Vec8(out.reg);
  }

  /// Lane-wise product.
  FP32Vec8 operator*(const FP32Vec8& b) const {
    f32x4x2_t out;
    out.val[0] = vec_mul(reg.val[0], b.reg.val[0]);
    out.val[1] = vec_mul(reg.val[1], b.reg.val[1]);
    return FP32Vec8(out);
  }

  /// Lane-wise sum.
  FP32Vec8 operator+(const FP32Vec8& b) const {
    f32x4x2_t out;
    out.val[0] = vec_add(reg.val[0], b.reg.val[0]);
    out.val[1] = vec_add(reg.val[1], b.reg.val[1]);
    return FP32Vec8(out);
  }

  /// Lane-wise difference.
  FP32Vec8 operator-(const FP32Vec8& b) const {
    f32x4x2_t out;
    out.val[0] = vec_sub(reg.val[0], b.reg.val[0]);
    out.val[1] = vec_sub(reg.val[1], b.reg.val[1]);
    return FP32Vec8(out);
  }

  /// Lane-wise quotient.
  FP32Vec8 operator/(const FP32Vec8& b) const {
    f32x4x2_t out;
    out.val[0] = vec_div(reg.val[0], b.reg.val[0]);
    out.val[1] = vec_div(reg.val[1], b.reg.val[1]);
    return FP32Vec8(out);
  }

  /// Stores eight floats to `ptr` (vector stores at byte offsets 0 and 16).
  void save(float* ptr) const {
    vec_xst(reg.val[0], 0, ptr);
    vec_xst(reg.val[1], 16, ptr);
  }
};
/// Sixteen floats stored as four `__vector float` registers.
struct FP32Vec16 : public Vec<FP32Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;

  /// Overlay giving per-element access to the four registers.
  union AliasReg {
    f32x4x4_t reg;
    float values[VEC_ELEM_NUM];
  };

  f32x4x4_t reg;

  /// Broadcasts `v` to every lane.
  explicit FP32Vec16(float v) {
    for (auto& quarter : reg.val) {
      quarter = vec_splats(v);
    }
  }

  /// All lanes zero.
  explicit FP32Vec16() : FP32Vec16(0.0f) {}

  /// Loads sixteen floats from `ptr` (byte offsets 0, 16, 32 and 48).
  explicit FP32Vec16(const float* ptr) {
    reg.val[0] = vec_xl(0, ptr);
    reg.val[1] = vec_xl(16, ptr);
    reg.val[2] = vec_xl(32, ptr);
    reg.val[3] = vec_xl(48, ptr);
  }

  /// Wraps an already-populated register bundle.
  explicit FP32Vec16(f32x4x4_t data) : reg(data) {}

  /// Copies all four registers of another FP32Vec16.
  explicit FP32Vec16(const FP32Vec16& data) : reg(data.reg) {}

  /// Replicates the four floats of `data` into every register.
  explicit FP32Vec16(const FP32Vec4& data) {
    for (auto& quarter : reg.val) {
      quarter = data.reg;
    }
  }

  /// Replicates the eight floats of `data` twice (lanes 0-7 and lanes 8-15).
  explicit FP32Vec16(const FP32Vec8& data) {
    for (int i = 0; i < 4; ++i) {
      reg.val[i] = data.reg.val[i % 2];
    }
  }

  /// Widens sixteen 16-bit lanes to 32-bit words by interleaving them with the
  /// file-level `zero` vector and reinterpreting the result as floats.
  // NOTE(review): confirm that vec_mergeh/vec_mergel(zero, x) places the
  // 16-bit payload in the upper half of each 32-bit word on this target; the
  // narrowing constructors in this file take the upper half (`>> 16`).
  explicit FP32Vec16(const BF16Vec16& v) {
    reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
    reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
    reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]);
    reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]);
  }

  /// Widens eight 16-bit lanes to floats, then replicates them twice.
  explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}

  /// Lane-wise product.
  FP32Vec16 operator*(const FP32Vec16& b) const {
    f32x4x4_t out;
    for (int i = 0; i < 4; ++i) {
      out.val[i] = vec_mul(reg.val[i], b.reg.val[i]);
    }
    return FP32Vec16(out);
  }

  /// Lane-wise sum.
  FP32Vec16 operator+(const FP32Vec16& b) const {
    f32x4x4_t out;
    for (int i = 0; i < 4; ++i) {
      out.val[i] = vec_add(reg.val[i], b.reg.val[i]);
    }
    return FP32Vec16(out);
  }

  /// Lane-wise difference.
  FP32Vec16 operator-(const FP32Vec16& b) const {
    f32x4x4_t out;
    for (int i = 0; i < 4; ++i) {
      out.val[i] = vec_sub(reg.val[i], b.reg.val[i]);
    }
    return FP32Vec16(out);
  }

  /// Lane-wise quotient.
  FP32Vec16 operator/(const FP32Vec16& b) const {
    f32x4x4_t out;
    for (int i = 0; i < 4; ++i) {
      out.val[i] = vec_div(reg.val[i], b.reg.val[i]);
    }
    return FP32Vec16(out);
  }

  /// Sums all sixteen lanes in index order, starting from 0.
  float reduce_sum() const {
    AliasReg lanes;
    lanes.reg = reg;
    float total = 0;
    unroll_loop<int, VEC_ELEM_NUM>(
        [&total, &lanes](int i) { total += lanes.values[i]; });
    return total;
  }

  /// Sums the `group_size` consecutive lanes that start at
  /// `idx * group_size`. `group_size` must divide VEC_ELEM_NUM; `idx` is not
  /// range-checked.
  template <int group_size>
  float reduce_sub_sum(int idx) {
    static_assert(VEC_ELEM_NUM % group_size == 0);
    AliasReg lanes;
    lanes.reg = reg;
    float total = 0;
    const int first = idx * group_size;
    // `lanes` is captured by value, as in the original implementation.
    unroll_loop<int, group_size>(
        [&total, &first, lanes](int i) { total += lanes.values[first + i]; });
    return total;
  }

  /// Stores sixteen floats to `ptr` (byte offsets 0, 16, 32 and 48).
  void save(float* ptr) const {
    vec_xst(reg.val[0], 0, ptr);
    vec_xst(reg.val[1], 16, ptr);
    vec_xst(reg.val[2], 32, ptr);
    vec_xst(reg.val[3], 48, ptr);
  }
};
/// Maps a scalar element type to the vector wrapper used for it.
/// Types without a specialization map to `void`.
template <typename T>
struct VecType {
  using vec_type = void;
};

/// Shorthand for `VecType<T>::vec_type`.
template <typename T>
using vec_t = typename VecType<T>::vec_type;

/// float -> FP32Vec8
template <>
struct VecType<float> {
  using vec_type = FP32Vec8;
};

/// c10::BFloat16 -> BF16Vec8
template <>
struct VecType<c10::BFloat16> {
  using vec_type = BF16Vec8;
};
/// Stores the float `v` into `*ptr`, relying on `T`'s assignment from float
/// for any conversion.
template <typename T>
void storeFP32(float v, T* ptr) {
  T& destination = *ptr;
  destination = v;
}
/// Multiply-accumulate on sixteen lanes: `acc = acc + a * b`.
/// Written as a separate multiply and add through FP32Vec16's operators.
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
  const FP32Vec16 product(a * b);
  acc = acc + product;
}
namespace c10 {
struct BFloat16 {
  uint16_t value;  // Assume BFloat16 is defined as a struct containing a 16-bit
                   // value.
};
}  // namespace c10

/// Stores `v` as a 16-bit value holding the upper 16 bits (sign, exponent and
/// top mantissa bits) of its 32-bit representation. The low 16 bits are
/// dropped; no rounding is applied.
///
/// The bits are extracted arithmetically rather than by indexing into the
/// float's storage: which half sits at the higher address depends on byte
/// order, and on a big-endian host `*(ptr_to_halves + 1)` is the low half.
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
  static_assert(sizeof(float) == sizeof(uint32_t),
                "float is expected to be 32 bits wide");
  uint32_t bits;
  __builtin_memcpy(&bits, &v, sizeof(bits));
  ptr->value = static_cast<uint16_t>(bits >> 16);
}
// Class mask passed to vec_fp_test_data_class to flag NaN lanes; defined here
// only when the toolchain headers have not already provided it.
#ifndef __VEC_CLASS_FP_NAN
  #define __VEC_CLASS_FP_NAN (1 << 6)
#endif

// Byte-selection pattern for vec_perm: bytes 2 and 3 of every 32-bit word
// across the two source vectors.
static const __vector unsigned char omask = {2,  3,  6,  7,  10, 11, 14, 15,
                                             18, 19, 22, 23, 26, 27, 30, 31};

// 0x7fff in every 32-bit lane.
static const __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff,
                                           0x00007fff};

// Bit pattern substituted for lanes flagged as NaN during narrowing.
static const __vector unsigned int nan = {0x7fc00000, 0x7fc00000, 0x7fc00000,
                                          0x7fc00000};

// Per-lane shift count of 16.
static const __vector unsigned int sh16 = {16, 16, 16, 16};

// 1 in every 32-bit lane.
static const __vector unsigned int one = {1, 1, 1, 1};
/// Narrows eight floats to eight 16-bit lanes.
///
/// For each group of four floats: lanes flagged as NaN are replaced by the
/// file-level `nan` bit pattern, then every 32-bit word is shifted right by 16.
/// vec_perm with `omask` finally gathers bytes 2-3 of each word from both
/// groups into one vector. No rounding step is applied here.
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
  auto upper_halves = [](__vector float lanes) {
    int cc;  // output slot required by the intrinsic; its value is not read
    __vector unsigned int bits = (__vector unsigned int)(lanes);
    __vector __bool int is_nan =
        vec_fp_test_data_class(lanes, __VEC_CLASS_FP_NAN, &cc);
    return vec_sel(bits, nan, is_nan) >> sh16;
  };

  __vector unsigned int lo = upper_halves(v.reg.val[0]);
  __vector unsigned int hi = upper_halves(v.reg.val[1]);
  reg = (__vector signed short)vec_perm(lo, hi, omask);
}
/// Narrows sixteen floats to sixteen 16-bit lanes.
///
/// For each group of four floats: lanes flagged as NaN are replaced by the
/// file-level `nan` bit pattern, then every 32-bit word is shifted right by 16.
/// Groups 0/1 and 2/3 are then packed pairwise with vec_perm and `omask`,
/// which gathers bytes 2-3 of each word. No rounding step is applied here.
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
  auto upper_halves = [](__vector float lanes) {
    int cc;  // output slot required by the intrinsic; its value is not read
    __vector unsigned int bits = (__vector unsigned int)(lanes);
    __vector __bool int is_nan =
        vec_fp_test_data_class(lanes, __VEC_CLASS_FP_NAN, &cc);
    return vec_sel(bits, nan, is_nan) >> sh16;
  };

  __vector unsigned int w0 = upper_halves(v.reg.val[0]);
  __vector unsigned int w1 = upper_halves(v.reg.val[1]);
  __vector unsigned int w2 = upper_halves(v.reg.val[2]);
  __vector unsigned int w3 = upper_halves(v.reg.val[3]);

  reg.val[0] = (__vector signed short)vec_perm(w0, w1, omask);
  reg.val[1] = (__vector signed short)vec_perm(w2, w3, omask);
}
/// Hints that the cache line containing `addr` will be read soon.
///
/// The previous body was `void __dcbt(const void* addr);`, which is a
/// block-scope function declaration rather than a call, so no prefetch was
/// ever issued. The compiler builtin emits the target's prefetch instruction
/// where one exists and is a no-op otherwise; it never faults, even for an
/// invalid address.
inline void prefetch(const void* addr) { __builtin_prefetch(addr); }
};
// namespace vec_op
#endif
\ No newline at end of file
csrc/cpu/cpu_types_x86.hpp
View file @
469e903b
...
@@ -16,9 +16,18 @@ namespace vec_op {
...
@@ -16,9 +16,18 @@ namespace vec_op {
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(__VA_ARGS__))
#ifndef CPU_OP_GUARD
#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
...
...
csrc/cpu/pos_encoding.cpp
View file @
469e903b
...
@@ -170,7 +170,7 @@ void rotary_embedding_gptj_impl(
...
@@ -170,7 +170,7 @@ void rotary_embedding_gptj_impl(
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int64_t
head_size
,
torch
::
Tensor
&
key
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
)
{
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
)
{
int
num_tokens
=
query
.
numel
()
/
query
.
size
(
-
1
);
int
num_tokens
=
positions
.
numel
(
);
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
num_heads
=
query
.
size
(
-
1
)
/
head_size
;
int
num_heads
=
query
.
size
(
-
1
)
/
head_size
;
int
num_kv_heads
=
key
.
size
(
-
1
)
/
head_size
;
int
num_kv_heads
=
key
.
size
(
-
1
)
/
head_size
;
...
...
csrc/cpu/quant.cpp
View file @
469e903b
...
@@ -25,7 +25,7 @@ struct KernelVecType<c10::BFloat16> {
...
@@ -25,7 +25,7 @@ struct KernelVecType<c10::BFloat16> {
template
<
>
template
<
>
struct
KernelVecType
<
c10
::
Half
>
{
struct
KernelVecType
<
c10
::
Half
>
{
#ifdef
__powerpc64__
#if
def
ined(
__powerpc64__
) || defined(__s390x__)
// Power architecture-specific vector type
// Power architecture-specific vector type
using
load_vec_type
=
vec_op
::
FP32Vec16
;
using
load_vec_type
=
vec_op
::
FP32Vec16
;
#else
#else
...
...
csrc/cuda_utils.h
View file @
469e903b
...
@@ -2,10 +2,14 @@
...
@@ -2,10 +2,14 @@
#include <stdio.h>
#include <stdio.h>
#if defined(__CUDACC__) || defined(_NVHPC_CUDA)
#if defined(__HIPCC__)
#define HOST_DEVICE_INLINE __forceinline__ __host__ __device__
#define HOST_DEVICE_INLINE __host__ __device__
#define DEVICE_INLINE __forceinline__ __device__
#define DEVICE_INLINE __device__
#define HOST_INLINE __forceinline__ __host__
#define HOST_INLINE __host__
#elif defined(__CUDACC__) || defined(_NVHPC_CUDA)
#define HOST_DEVICE_INLINE __host__ __device__ __forceinline__
#define DEVICE_INLINE __device__ __forceinline__
#define HOST_INLINE __host__ __forceinline__
#else
#else
#define HOST_DEVICE_INLINE inline
#define HOST_DEVICE_INLINE inline
#define DEVICE_INLINE inline
#define DEVICE_INLINE inline
...
@@ -25,3 +29,13 @@
...
@@ -25,3 +29,13 @@
int64_t
get_device_attribute
(
int64_t
attribute
,
int64_t
device_id
);
int64_t
get_device_attribute
(
int64_t
attribute
,
int64_t
device_id
);
int64_t
get_max_shared_memory_per_block_device_attribute
(
int64_t
device_id
);
int64_t
get_max_shared_memory_per_block_device_attribute
(
int64_t
device_id
);
namespace cuda_utils {

/// Integer division rounded up, computed as `(a + b - 1) / b`.
/// Only participates in overload resolution for integral `T`.
template <typename T>
HOST_DEVICE_INLINE constexpr std::enable_if_t<std::is_integral_v<T>, T>
ceil_div(T a, T b) {
  // Kept in the promoted type of the expression so the result matches the
  // single-expression form exactly.
  const auto rounded_up = a + b - 1;
  return rounded_up / b;
}

}  // namespace cuda_utils
\ No newline at end of file
csrc/custom_all_reduce.cu
View file @
469e903b
...
@@ -142,3 +142,44 @@ void register_graph_buffers(fptr_t _fa,
...
@@ -142,3 +142,44 @@ void register_graph_buffers(fptr_t _fa,
bytes
.
reserve
(
handles
.
size
());
bytes
.
reserve
(
handles
.
size
());
fa
->
register_graph_buffers
(
bytes
,
offsets
);
fa
->
register_graph_buffers
(
bytes
,
offsets
);
}
}
std
::
tuple
<
fptr_t
,
torch
::
Tensor
>
allocate_shared_buffer_and_handle
(
int64_t
size
)
{
auto
device_index
=
c10
::
cuda
::
current_device
();
at
::
DeviceGuard
device_guard
(
at
::
Device
(
at
::
DeviceType
::
CUDA
,
device_index
));
void
*
buffer
;
cudaStreamCaptureMode
mode
=
cudaStreamCaptureModeRelaxed
;
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
AT_CUDA_CHECK
(
cudaThreadExchangeStreamCaptureMode
(
&
mode
));
#if defined(USE_ROCM)
// data buffers need to be "uncached" for signal on MI200
AT_CUDA_CHECK
(
hipExtMallocWithFlags
((
void
**
)
&
buffer
,
size
,
hipDeviceMallocUncached
));
#else
AT_CUDA_CHECK
(
cudaMalloc
((
void
**
)
&
buffer
,
size
));
#endif
AT_CUDA_CHECK
(
cudaMemsetAsync
(
buffer
,
0
,
size
,
stream
));
AT_CUDA_CHECK
(
cudaStreamSynchronize
(
stream
));
AT_CUDA_CHECK
(
cudaThreadExchangeStreamCaptureMode
(
&
mode
));
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
torch
::
kCPU
);
auto
handle
=
torch
::
empty
({
static_cast
<
int64_t
>
(
sizeof
(
cudaIpcMemHandle_t
))},
options
);
AT_CUDA_CHECK
(
cudaIpcGetMemHandle
((
cudaIpcMemHandle_t
*
)
handle
.
data_ptr
(),
buffer
));
return
std
::
make_tuple
(
reinterpret_cast
<
fptr_t
>
(
buffer
),
handle
);
}
/// Opens the CUDA IPC memory handle stored in `mem_handle`'s data buffer and
/// returns the resulting device pointer cast to fptr_t.
fptr_t open_mem_handle(torch::Tensor& mem_handle) {
  const auto* ipc_handle =
      reinterpret_cast<const cudaIpcMemHandle_t*>(mem_handle.data_ptr());
  void* mapped_ptr;
  AT_CUDA_CHECK(cudaIpcOpenMemHandle(&mapped_ptr, *ipc_handle,
                                     cudaIpcMemLazyEnablePeerAccess));
  return reinterpret_cast<fptr_t>(mapped_ptr);
}
/// Releases, via cudaFree, the device memory whose address is carried in
/// `buffer`.
void free_shared_buffer(fptr_t buffer) {
  void* device_ptr = reinterpret_cast<void*>(buffer);
  AT_CUDA_CHECK(cudaFree(device_ptr));
}
csrc/custom_all_reduce.cuh
View file @
469e903b
...
@@ -5,6 +5,10 @@
...
@@ -5,6 +5,10 @@
#include <cuda_fp16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cuda_runtime.h>
#if defined(USE_ROCM)
typedef
__hip_bfloat16
nv_bfloat16
;
#endif
#include <iostream>
#include <iostream>
#include <array>
#include <array>
#include <limits>
#include <limits>
...
@@ -12,6 +16,7 @@
...
@@ -12,6 +16,7 @@
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
namespace
vllm
{
#define CUDACHECK(cmd) \
#define CUDACHECK(cmd) \
do { \
do { \
cudaError_t e = cmd; \
cudaError_t e = cmd; \
...
@@ -22,24 +27,37 @@
...
@@ -22,24 +27,37 @@
} \
} \
} while (0)
} while (0)
namespace
vllm
{
// Maximal number of blocks in allreduce kernel.
constexpr
int
kMaxBlocks
=
36
;
constexpr
int
kMaxBlocks
=
36
;
// Default number of blocks in allreduce kernel.
#ifndef USE_ROCM
const
int
defaultBlockLimit
=
36
;
CUpointer_attribute
rangeStartAddrAttr
=
CUDA_POINTER_ATTRIBUTE_RANGE_START_ADDR
;
#else
const
int
defaultBlockLimit
=
16
;
hipPointer_attribute
rangeStartAddrAttr
=
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR
;
#endif
// Counter may overflow, but it's fine since unsigned int overflow is
// Counter may overflow, but it's fine since unsigned int overflow is
// well-defined behavior.
// well-defined behavior.
using
FlagType
=
uint32_t
;
using
FlagType
=
uint32_t
;
// Two sets of peer counters are needed for two syncs. The reason is that
// it's possible for peer GPU block to arrive at the second sync point while
// the current GPU block haven't passed the first sync point. Thus, peer GPU
// may write counter+1 while current GPU is busy waiting for counter. We use
// alternating counter array to avoid this possibility.
struct
Signal
{
struct
Signal
{
alignas
(
128
)
FlagType
self_counter
[
kMaxBlocks
][
8
];
alignas
(
128
)
FlagType
start
[
kMaxBlocks
][
8
];
// Two sets of peer counters are needed for two syncs. The reason is that
alignas
(
128
)
FlagType
end
[
kMaxBlocks
][
8
];
// it's possible for peer GPU block to arrive at the second sync point while
alignas
(
128
)
FlagType
_flag
[
kMaxBlocks
];
// incremental flags for each rank
// the current GPU block haven't passed the first sync point. Thus, peer GPU
// may write counter+1 while current GPU is busy waiting for counter. We use
// alternating counter array to avoid this possibility.
alignas
(
128
)
FlagType
peer_counter
[
2
][
kMaxBlocks
][
8
];
};
};
struct
__align__
(
16
)
RankData
{
struct
__align__
(
16
)
RankData
{
const
void
*
__restrict__
ptrs
[
8
];
const
void
*
ptrs
[
8
];
};
};
struct
__align__
(
16
)
RankSignals
{
struct
__align__
(
16
)
RankSignals
{
...
@@ -134,27 +152,29 @@ DINLINE O downcast(array_t<float, O::size> val) {
...
@@ -134,27 +152,29 @@ DINLINE O downcast(array_t<float, O::size> val) {
}
}
}
}
#if !defined(USE_ROCM)
static
DINLINE
void
st_flag_release
(
FlagType
*
flag_addr
,
FlagType
flag
)
{
static
DINLINE
void
st_flag_release
(
FlagType
*
flag_addr
,
FlagType
flag
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm
volatile
(
"st.release.sys.global.u32 [%1], %0;"
::
"r"
(
flag
),
asm
volatile
(
"st.release.sys.global.u32 [%1], %0;"
::
"r"
(
flag
),
"l"
(
flag_addr
));
"l"
(
flag_addr
));
#else
#else
asm
volatile
(
"membar.sys; st.volatile.global.u32 [%1], %0;"
::
"r"
(
flag
),
asm
volatile
(
"membar.sys; st.volatile.global.u32 [%1], %0;"
::
"r"
(
flag
),
"l"
(
flag_addr
));
"l"
(
flag_addr
));
#endif
#endif
}
}
static
DINLINE
FlagType
ld_flag_acquire
(
FlagType
*
flag_addr
)
{
static
DINLINE
FlagType
ld_flag_acquire
(
FlagType
*
flag_addr
)
{
FlagType
flag
;
FlagType
flag
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm
volatile
(
"ld.acquire.sys.global.u32 %0, [%1];"
asm
volatile
(
"ld.acquire.sys.global.u32 %0, [%1];"
:
"=r"
(
flag
)
:
"=r"
(
flag
)
:
"l"
(
flag_addr
));
:
"l"
(
flag_addr
));
#else
#else
asm
volatile
(
"ld.volatile.global.u32 %0, [%1]; membar.gl;"
asm
volatile
(
"ld.volatile.global.u32 %0, [%1]; membar.gl;"
:
"=r"
(
flag
)
:
"=r"
(
flag
)
:
"l"
(
flag_addr
));
:
"l"
(
flag_addr
));
#endif
#endif
return
flag
;
return
flag
;
}
}
...
@@ -170,37 +190,108 @@ static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
...
@@ -170,37 +190,108 @@ static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
return
flag
;
return
flag
;
}
}
// is_start: whether this is the very first synchronization barrier.
// This function is meant to be used as the first synchronization in the all
// need_fence: whether a memory fence is needed. If true, a release-acquire
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// semantic is used to enforce memory access order before and after this
// prior memory accesses. Note: volatile writes will not be reordered against
// barrier.
// other volatile writes.
template
<
int
ngpus
,
bool
is_start
,
bool
need_fence
=
false
>
template
<
int
ngpus
>
DINLINE
void
multi_gpu_barrier
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
DINLINE
void
start_sync
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
int
rank
)
{
int
rank
)
{
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
if
constexpr
(
!
is_start
)
__syncthreads
();
static_assert
(
!
(
is_start
&&
need_fence
));
// Start barrier shouldn't need fence.
if
(
threadIdx
.
x
<
ngpus
)
{
if
(
threadIdx
.
x
<
ngpus
)
{
// Increment the counter. Technically we only need one counter, but we use
auto
peer_counter_ptr
=
&
sg
.
signals
[
threadIdx
.
x
]
->
start
[
blockIdx
.
x
][
rank
];
// multiple per block to eliminate the need to share the counter via smem.
auto
self_counter_ptr
=
&
self_sg
->
start
[
blockIdx
.
x
][
threadIdx
.
x
];
auto
val
=
self_sg
->
self_counter
[
blockIdx
.
x
][
threadIdx
.
x
]
+=
1
;
// Write the expected counter value to peer and wait for correct value
// from peer.
st_flag_volatile
(
peer_counter_ptr
,
flag
);
while
(
ld_flag_volatile
(
self_counter_ptr
)
!=
flag
);
}
__syncthreads
();
// use one thread to update flag
if
(
threadIdx
.
x
==
0
)
self_sg
->
_flag
[
blockIdx
.
x
]
=
flag
;
}
// This function is meant to be used as the second or the final
// synchronization barrier in the all reduce kernel. If it's the final
// synchronization barrier, we don't need to make any visibility guarantees
// for prior memory accesses.
template
<
int
ngpus
,
bool
final_sync
=
false
>
DINLINE
void
end_sync
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
int
rank
)
{
__syncthreads
();
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
if
(
threadIdx
.
x
<
ngpus
)
{
auto
peer_counter_ptr
=
&
sg
.
signals
[
threadIdx
.
x
]
->
end
[
blockIdx
.
x
][
rank
];
auto
self_counter_ptr
=
&
self_sg
->
end
[
blockIdx
.
x
][
threadIdx
.
x
];
// Write the expected counter value to peer and wait for correct value from
// Write the expected counter value to peer and wait for correct value from
// peer.
// peer.
auto
peer_counter_ptr
=
if
constexpr
(
!
final_sync
)
{
&
sg
.
signals
[
threadIdx
.
x
]
->
peer_counter
[
val
%
2
][
blockIdx
.
x
][
rank
];
st_flag_release
(
peer_counter_ptr
,
flag
);
auto
self_counter_ptr
=
while
(
ld_flag_acquire
(
self_counter_ptr
)
!=
flag
);
&
self_sg
->
peer_counter
[
val
%
2
][
blockIdx
.
x
][
threadIdx
.
x
];
if
constexpr
(
need_fence
)
{
st_flag_release
(
peer_counter_ptr
,
val
);
while
(
ld_flag_acquire
(
self_counter_ptr
)
!=
val
);
}
else
{
}
else
{
st_flag_volatile
(
peer_counter_ptr
,
val
);
st_flag_volatile
(
peer_counter_ptr
,
flag
);
while
(
ld_flag_volatile
(
self_counter_ptr
)
!=
val
);
while
(
ld_flag_volatile
(
self_counter_ptr
)
!=
flag
);
}
}
}
}
if
constexpr
(
is_start
||
need_fence
)
__syncthreads
();
if
constexpr
(
!
final_sync
)
__syncthreads
();
// use one thread to update flag
if
(
threadIdx
.
x
==
0
)
self_sg
->
_flag
[
blockIdx
.
x
]
=
flag
;
}
#else
template
<
int
ngpus
>
DINLINE
void
start_sync
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
int
rank
)
{
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
if
(
threadIdx
.
x
<
ngpus
)
{
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
// __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
// flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
// // wait until we got true from all ranks
// while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
// __ATOMIC_RELAXED,
// __MEMORY_SCOPE_DEVICE) < flag);
__atomic_store_n
(
&
sg
.
signals
[
threadIdx
.
x
]
->
start
[
blockIdx
.
x
][
rank
],
flag
,
__ATOMIC_RELAXED
);
// wait until we got true from all ranks
while
(
__atomic_load_n
(
&
self_sg
->
start
[
blockIdx
.
x
][
threadIdx
.
x
],
__ATOMIC_RELAXED
)
<
flag
);
}
__syncthreads
();
// use one thread to update flag
if
(
threadIdx
.
x
==
0
)
self_sg
->
_flag
[
blockIdx
.
x
]
=
flag
;
}
template
<
int
ngpus
,
bool
final_sync
=
false
>
DINLINE
void
end_sync
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
int
rank
)
{
__syncthreads
();
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
if
(
threadIdx
.
x
<
ngpus
)
{
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
// __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
// flag,
// final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
// __MEMORY_SCOPE_SYSTEM);
// // wait until we got true from all ranks
// while (
// __scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
// final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
// __MEMORY_SCOPE_DEVICE) < flag);
__atomic_store_n
(
&
sg
.
signals
[
threadIdx
.
x
]
->
end
[
blockIdx
.
x
][
rank
],
flag
,
final_sync
?
__ATOMIC_RELAXED
:
__ATOMIC_RELEASE
);
// wait until we got true from all ranks
while
(
__atomic_load_n
(
&
self_sg
->
end
[
blockIdx
.
x
][
threadIdx
.
x
],
final_sync
?
__ATOMIC_RELAXED
:
__ATOMIC_ACQUIRE
)
<
flag
);
}
if
constexpr
(
!
final_sync
)
__syncthreads
();
// use one thread to update flag
if
(
threadIdx
.
x
==
0
)
self_sg
->
_flag
[
blockIdx
.
x
]
=
flag
;
}
}
#endif
template
<
typename
P
,
int
ngpus
,
typename
A
>
template
<
typename
P
,
int
ngpus
,
typename
A
>
DINLINE
P
packed_reduce
(
const
P
*
ptrs
[],
int
idx
)
{
DINLINE
P
packed_reduce
(
const
P
*
ptrs
[],
int
idx
)
{
A
tmp
=
upcast
(
ptrs
[
0
][
idx
]);
A
tmp
=
upcast
(
ptrs
[
0
][
idx
]);
...
@@ -220,13 +311,13 @@ __global__ void __launch_bounds__(512, 1)
...
@@ -220,13 +311,13 @@ __global__ void __launch_bounds__(512, 1)
// note: we don't reorder the address so the accumulation order is the same
// note: we don't reorder the address so the accumulation order is the same
// for all ranks, ensuring bitwise identical results
// for all ranks, ensuring bitwise identical results
auto
dp
=
*
_dp
;
auto
dp
=
*
_dp
;
multi_gpu_barrier
<
ngpus
,
true
>
(
sg
,
self_sg
,
rank
);
start_sync
<
ngpus
>
(
sg
,
self_sg
,
rank
);
// do the actual reduction
// do the actual reduction
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
((
P
*
)
result
)[
idx
]
=
packed_reduce
<
P
,
ngpus
,
A
>
((
const
P
**
)
&
dp
.
ptrs
[
0
],
idx
);
((
P
*
)
result
)[
idx
]
=
packed_reduce
<
P
,
ngpus
,
A
>
((
const
P
**
)
&
dp
.
ptrs
[
0
],
idx
);
}
}
multi_gpu_barrier
<
ngpus
,
fals
e
>
(
sg
,
self_sg
,
rank
);
end_sync
<
ngpus
,
tru
e
>
(
sg
,
self_sg
,
rank
);
}
}
template
<
typename
P
>
template
<
typename
P
>
...
@@ -255,12 +346,13 @@ __global__ void __launch_bounds__(512, 1)
...
@@ -255,12 +346,13 @@ __global__ void __launch_bounds__(512, 1)
tmps
[
i
]
=
get_tmp_buf
<
P
>
(
sg
.
signals
[
target
]);
tmps
[
i
]
=
get_tmp_buf
<
P
>
(
sg
.
signals
[
target
]);
}
}
auto
tmp_out
=
tmps
[
0
];
auto
tmp_out
=
tmps
[
0
];
multi_gpu_barrier
<
ngpus
,
true
>
(
sg
,
self_sg
,
rank
);
start_sync
<
ngpus
>
(
sg
,
self_sg
,
rank
);
// stage 1: reduce scatter
// stage 1: reduce scatter
for
(
int
idx
=
start
+
tid
;
idx
<
end
;
idx
+=
stride
)
{
for
(
int
idx
=
start
+
tid
;
idx
<
end
;
idx
+=
stride
)
{
tmp_out
[
idx
-
start
]
=
packed_reduce
<
P
,
ngpus
,
A
>
(
ptrs
,
idx
);
tmp_out
[
idx
-
start
]
=
packed_reduce
<
P
,
ngpus
,
A
>
(
ptrs
,
idx
);
}
}
multi_gpu_barrier
<
ngpus
,
false
,
true
>
(
sg
,
self_sg
,
rank
);
end_sync
<
ngpus
>
(
sg
,
self_sg
,
rank
);
// stage 2: allgather. Note: it's important to match the tid between
// stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed
// the two stages, because visibility across devices is only guaranteed
...
@@ -290,7 +382,7 @@ class CustomAllreduce {
...
@@ -290,7 +382,7 @@ class CustomAllreduce {
bool
full_nvlink_
;
bool
full_nvlink_
;
RankSignals
sg_
;
RankSignals
sg_
;
// Stores an map from a pointer to its peer point
t
ers from all ranks.
// Stores an map from a pointer to its peer pointers from all ranks.
std
::
unordered_map
<
void
*
,
RankData
*>
buffers_
;
std
::
unordered_map
<
void
*
,
RankData
*>
buffers_
;
Signal
*
self_sg_
;
Signal
*
self_sg_
;
...
@@ -361,8 +453,7 @@ class CustomAllreduce {
...
@@ -361,8 +453,7 @@ class CustomAllreduce {
void
*
base_ptr
;
void
*
base_ptr
;
// note: must share the base address of each allocation, or we get wrong
// note: must share the base address of each allocation, or we get wrong
// address
// address
if
(
cuPointerGetAttribute
(
&
base_ptr
,
if
(
cuPointerGetAttribute
(
&
base_ptr
,
rangeStartAddrAttr
,
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR
,
(
CUdeviceptr
)
ptr
)
!=
CUDA_SUCCESS
)
(
CUdeviceptr
)
ptr
)
!=
CUDA_SUCCESS
)
throw
std
::
runtime_error
(
"failed to get pointer attr"
);
throw
std
::
runtime_error
(
"failed to get pointer attr"
);
CUDACHECK
(
cudaIpcGetMemHandle
(
CUDACHECK
(
cudaIpcGetMemHandle
(
...
@@ -439,7 +530,7 @@ class CustomAllreduce {
...
@@ -439,7 +530,7 @@ class CustomAllreduce {
*/
*/
template
<
typename
T
>
template
<
typename
T
>
void
allreduce
(
cudaStream_t
stream
,
T
*
input
,
T
*
output
,
int
size
,
void
allreduce
(
cudaStream_t
stream
,
T
*
input
,
T
*
output
,
int
size
,
int
threads
=
512
,
int
block_limit
=
36
)
{
int
threads
=
512
,
int
block_limit
=
defaultBlockLimit
)
{
auto
d
=
packed_t
<
T
>::
P
::
size
;
auto
d
=
packed_t
<
T
>::
P
::
size
;
if
(
size
%
d
!=
0
)
if
(
size
%
d
!=
0
)
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
...
@@ -473,8 +564,6 @@ class CustomAllreduce {
...
@@ -473,8 +564,6 @@ class CustomAllreduce {
#define KL(ngpus, name) \
#define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size);
rank_, size);
// TODO(hanzhi713): Threshold is different for A100 and H100.
// Add per device threshold.
#define REDUCE_CASE(ngpus) \
#define REDUCE_CASE(ngpus) \
case ngpus: { \
case ngpus: { \
if (world_size_ == 2) { \
if (world_size_ == 2) { \
...
@@ -497,7 +586,8 @@ class CustomAllreduce {
...
@@ -497,7 +586,8 @@ class CustomAllreduce {
REDUCE_CASE
(
8
)
REDUCE_CASE
(
8
)
default:
default:
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
"custom allreduce only supports num gpus in (2,4,6,8). Actual "
"num "
"gpus = "
+
"gpus = "
+
std
::
to_string
(
world_size_
));
std
::
to_string
(
world_size_
));
}
}
...
...
csrc/custom_all_reduce_test.cu
View file @
469e903b
...
@@ -20,9 +20,16 @@
...
@@ -20,9 +20,16 @@
#include <vector>
#include <vector>
#include "cuda_profiler_api.h"
#include "cuda_profiler_api.h"
#include "custom_all_reduce.cuh"
#include "mpi.h"
#include "mpi.h"
#include "nccl.h"
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef
__hip_bfloat16
nv_bfloat16
;
#include "rccl/rccl.h"
#include "custom_all_reduce_hip.cuh"
#else
#include "nccl.h"
#include "custom_all_reduce.cuh"
#endif
#define MPICHECK(cmd) \
#define MPICHECK(cmd) \
do { \
do { \
...
@@ -44,13 +51,16 @@
...
@@ -44,13 +51,16 @@
} while (0)
} while (0)
__global__
void
dummy_kernel
()
{
__global__
void
dummy_kernel
()
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
#ifdef USE_ROCM
for
(
int
i
=
0
;
i
<
100
;
i
++
)
__nanosleep
(
1000000
);
// 100ms
#else
for
(
int
i
=
0
;
i
<
100
;
i
++
)
{
for
(
int
i
=
0
;
i
<
100
;
i
++
)
{
long
long
int
start
=
clock64
();
uint64_t
start
=
wall_clock64
();
while
(
clock64
()
-
start
<
150000000
);
// approximately 98.4ms on P40
uint64_t
cycles_elapsed
;
do
{
cycles_elapsed
=
wall_clock64
()
-
start
;
}
while
(
cycles_elapsed
<
100
);
}
}
#else
for
(
int
i
=
0
;
i
<
100
;
i
++
)
__nanosleep
(
1000000
);
// 100ms
#endif
#endif
}
}
...
@@ -121,8 +131,14 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
...
@@ -121,8 +131,14 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
* registration, they are allocated and registered together in the test for
* registration, they are allocated and registered together in the test for
* convenience.
* convenience.
*/
*/
#ifdef USE_ROCM
CUDACHECK
(
hipExtMallocWithFlags
(
(
void
**
)
&
buffer
,
2
*
data_size
*
sizeof
(
T
)
+
sizeof
(
vllm
::
Signal
),
hipDeviceMallocUncached
));
#else
CUDACHECK
(
CUDACHECK
(
cudaMalloc
(
&
buffer
,
2
*
data_size
*
sizeof
(
T
)
+
sizeof
(
vllm
::
Signal
)));
cudaMalloc
(
&
buffer
,
2
*
data_size
*
sizeof
(
T
)
+
sizeof
(
vllm
::
Signal
)));
#endif
CUDACHECK
(
CUDACHECK
(
cudaMemset
(
buffer
,
0
,
2
*
data_size
*
sizeof
(
T
)
+
sizeof
(
vllm
::
Signal
)));
cudaMemset
(
buffer
,
0
,
2
*
data_size
*
sizeof
(
T
)
+
sizeof
(
vllm
::
Signal
)));
CUDACHECK
(
cudaMalloc
(
&
self_data_copy
,
data_size
*
sizeof
(
T
)));
CUDACHECK
(
cudaMalloc
(
&
self_data_copy
,
data_size
*
sizeof
(
T
)));
...
@@ -135,26 +151,24 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
...
@@ -135,26 +151,24 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
void
*
rank_data
;
void
*
rank_data
;
size_t
rank_data_sz
=
16
*
1024
*
1024
;
size_t
rank_data_sz
=
16
*
1024
*
1024
;
CUDACHECK
(
cudaMalloc
(
&
rank_data
,
rank_data_sz
));
CUDACHECK
(
cudaMalloc
(
&
rank_data
,
rank_data_sz
));
vllm
::
Signal
*
ipc_ptrs
[
8
];
std
::
vector
<
int64_t
>
offsets
(
nRanks
,
0
);
for
(
int
i
=
0
;
i
<
nRanks
;
i
++
)
{
vllm
::
CustomAllreduce
fa
(
buffer
,
rank_data
,
rank_data_sz
,
data_handles
,
if
(
i
==
myRank
)
offsets
,
myRank
);
ipc_ptrs
[
i
]
=
buffer
;
else
CUDACHECK
(
cudaIpcOpenMemHandle
((
void
**
)
&
ipc_ptrs
[
i
],
data_handles
[
i
],
cudaIpcMemLazyEnablePeerAccess
));
}
vllm
::
CustomAllreduce
fa
(
ipc_ptrs
,
rank_data
,
rank_data_sz
,
myRank
,
nRanks
);
auto
*
self_data
=
auto
*
self_data
=
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
char
*>
(
buffer
)
+
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
char
*>
(
buffer
)
+
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
));
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
));
// hack buffer registration
// hack buffer registration
{
{
void
*
data
[
8
];
std
::
vector
<
std
::
string
>
handles
;
handles
.
reserve
(
nRanks
);
for
(
int
i
=
0
;
i
<
nRanks
;
i
++
)
{
for
(
int
i
=
0
;
i
<
nRanks
;
i
++
)
{
data
[
i
]
=
char
*
begin
=
(
char
*
)
&
data_handles
[
i
];
((
char
*
)
ipc_ptrs
[
i
])
+
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
);
char
*
end
=
(
char
*
)
&
data_handles
[
i
+
1
];
handles
.
emplace_back
(
begin
,
end
);
}
}
fa
.
register_buffer
(
data
);
std
::
vector
<
int64_t
>
offsets
(
nRanks
,
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
));
fa
.
register_buffer
(
handles
,
offsets
,
self_data
);
}
}
double
*
ground_truth
;
double
*
ground_truth
;
...
@@ -266,14 +280,14 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
...
@@ -266,14 +280,14 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
if
(
diff
>=
4e-2
)
{
if
(
diff
>=
4e-2
)
{
printf
(
printf
(
"Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f
\n
"
,
"Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f
\n
"
,
myRank
,
j
,
nccl_result
[
j
],
my_result
[
j
],
ground_truth
[
j
]);
myRank
,
j
,
nccl_result
[
j
],
my_result
[
j
],
ground_truth
[
j
]);
break
;
break
;
}
}
}
}
}
}
if
(
myRank
==
0
)
if
(
myRank
==
0
)
printf
(
"Test passed: nGPUs:%d, sz (kb): %d, %d, %d
\n
"
,
nRanks
,
printf
(
"Test passed: nGPUs:%d, sz (kb): %d, %d, %d
\n
"
,
nRanks
,
data_size
*
sizeof
(
T
)
/
1024
,
threads
,
block_limit
);
data_size
*
sizeof
(
T
)
/
1024
,
threads
,
block_limit
);
// long double nccl_diffs = 0.0;
// long double nccl_diffs = 0.0;
// long double my_diffs = 0.0;
// long double my_diffs = 0.0;
// for (int j = 0; j < data_size; j++) {
// for (int j = 0; j < data_size; j++) {
...
@@ -306,24 +320,27 @@ int main(int argc, char** argv) {
...
@@ -306,24 +320,27 @@ int main(int argc, char** argv) {
ncclComm_t
comm
;
ncclComm_t
comm
;
if
(
myRank
==
0
)
ncclGetUniqueId
(
&
id
);
if
(
myRank
==
0
)
ncclGetUniqueId
(
&
id
);
MPICHECK
(
MPI_Bcast
(
static_cast
<
void
*>
(
&
id
),
sizeof
(
id
),
MPI_BYTE
,
0
,
MPICHECK
(
MPI_Bcast
(
static_cast
<
void
*>
(
&
id
),
sizeof
(
id
),
MPI_BYTE
,
0
,
MPI_COMM_WORLD
));
MPI_COMM_WORLD
));
NCCLCHECK
(
ncclCommInitRank
(
&
comm
,
nRanks
,
id
,
myRank
));
NCCLCHECK
(
ncclCommInitRank
(
&
comm
,
nRanks
,
id
,
myRank
));
bool
performance_test
=
true
;
bool
performance_test
=
true
;
cudaProfilerStart
();
cudaProfilerStart
();
// Uncomment to scan through different block size configs.
// for (int threads : {256, 512}) {
// for (int threads : {256, 512, 1024}) {
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
// run<half>(myRank, nRanks, comm, threads, block_limit, 1024 * 1024,
// run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
// performance_test);
// }
// }
// }
// }
// Scan through different sizes to test performance.
#ifdef USE_ROCM
for
(
int
sz
=
512
;
sz
<=
(
8
<<
20
);
sz
*=
2
)
{
run
<
half
>
(
myRank
,
nRanks
,
comm
,
512
,
16
,
sz
+
8
*
47
,
performance_test
);
}
#else
for
(
int
sz
=
512
;
sz
<=
(
8
<<
20
);
sz
*=
2
)
{
for
(
int
sz
=
512
;
sz
<=
(
8
<<
20
);
sz
*=
2
)
{
run
<
half
>
(
myRank
,
nRanks
,
comm
,
512
,
36
,
sz
+
8
*
47
,
performance_test
);
run
<
half
>
(
myRank
,
nRanks
,
comm
,
512
,
36
,
sz
+
8
*
47
,
performance_test
);
}
}
#endif
cudaProfilerStop
();
cudaProfilerStop
();
MPICHECK
(
MPI_Finalize
());
MPICHECK
(
MPI_Finalize
());
return
EXIT_SUCCESS
;
return
EXIT_SUCCESS
;
}
}
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
View file @
469e903b
...
@@ -122,8 +122,8 @@ struct ScaledEpilogue
...
@@ -122,8 +122,8 @@ struct ScaledEpilogue
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
,
{},
{}
};
return
ArgumentType
{
a_args
,
evt0_args
};
return
ArgumentType
{
a_args
,
evt0_args
,
{}
};
}
}
};
};
...
@@ -167,8 +167,8 @@ struct ScaledEpilogueBias
...
@@ -167,8 +167,8 @@ struct ScaledEpilogueBias
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
,
{},
{}
};
return
ArgumentType
{
a_args
,
evt0_args
,
bias_args
};
return
ArgumentType
{
a_args
,
evt0_args
,
bias_args
,
{}
};
}
}
};
};
...
@@ -230,9 +230,10 @@ struct ScaledEpilogueBiasAzp
...
@@ -230,9 +230,10 @@ struct ScaledEpilogueBiasAzp
auto
azp_adj_args
=
auto
azp_adj_args
=
SUPER
::
template
args_from_tensor
<
AzpWithAdj
,
int32_t
>(
azp_adj
);
SUPER
::
template
args_from_tensor
<
AzpWithAdj
,
int32_t
>(
azp_adj
);
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{{},
azp_adj_args
};
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{{},
azp_adj_args
,
{}};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
b_args
,
evt_azp_args
};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
};
b_args
,
evt_azp_args
,
{}};
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
,
{}};
}
}
};
};
...
@@ -309,11 +310,12 @@ struct ScaledEpilogueBiasAzpToken
...
@@ -309,11 +310,12 @@ struct ScaledEpilogueBiasAzpToken
auto
azp_adj_args
=
auto
azp_adj_args
=
SUPER
::
template
args_from_tensor
<
AzpAdj
,
int32_t
>(
azp_adj
);
SUPER
::
template
args_from_tensor
<
AzpAdj
,
int32_t
>(
azp_adj
);
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{
azp_args
,
azp_adj_args
};
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{
azp_args
,
azp_adj_args
,
{}};
typename
EVTComputeAcc
::
Arguments
evt_acc_args
{{},
evt_azp_args
};
typename
EVTComputeAcc
::
Arguments
evt_acc_args
{{},
evt_azp_args
,
{}};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
b_args
,
evt_acc_args
};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
};
b_args
,
evt_acc_args
,
{}};
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
,
{}};
}
}
};
};
};
// namespace vllm::c2x
};
// namespace vllm::c2x
\ No newline at end of file
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
View file @
469e903b
...
@@ -22,7 +22,7 @@ struct identity {
...
@@ -22,7 +22,7 @@ struct identity {
T
operator
()(
T
lhs
)
const
{
return
lhs
;
}
T
operator
()(
T
lhs
)
const
{
return
lhs
;
}
};
};
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
TileShape
>
struct
TrivialEpilogue
{
struct
TrivialEpilogue
{
private:
private:
using
Accum
=
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
;
using
Accum
=
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
;
...
@@ -44,32 +44,30 @@ struct TrivialEpilogue {
...
@@ -44,32 +44,30 @@ struct TrivialEpilogue {
* This class provides the common load descriptors for the
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
* ScaledEpilogue[...] classes
*/
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
TileShape
>
struct
ScaledEpilogueBase
{
struct
ScaledEpilogueBase
{
protected:
protected:
using
Accum
=
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
;
using
Accum
=
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
;
template
<
typename
T
>
template
<
typename
T
>
using
ColOrScalarLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90ColOrScalarBroadcast
<
using
ColOrScalarLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90ColOrScalarBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
0
/*Stages*/
,
TileShape
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
template
<
typename
T
>
template
<
typename
T
>
using
RowOrScalarLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90RowOrScalarBroadcast
<
using
RowOrScalarLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90RowOrScalarBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
0
/*Stages*/
,
TileShape
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
// Don't want to support nullptr by default
// Don't want to support nullptr by default
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
using
ColLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90ColBroadcast
<
using
ColLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90ColBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
T
,
0
/*Stages*/
,
TileShape
,
T
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
// Don't want to support nullptr by default
// Don't want to support nullptr by default
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
using
RowLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
using
RowLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
T
,
0
/*Stages*/
,
TileShape
,
T
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
// This utility function constructs the arguments for the load descriptors
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// from a tensor. It can handle both row and column, as well as row/column or
...
@@ -116,11 +114,11 @@ struct ScaledEpilogueBase {
...
@@ -116,11 +114,11 @@ struct ScaledEpilogueBase {
the A and B operands respectively. These scales may be either per-tensor or
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
per row or column.
*/
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
TileShape
>
struct
ScaledEpilogue
struct
ScaledEpilogue
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
{
private:
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
...
@@ -146,8 +144,8 @@ struct ScaledEpilogue
...
@@ -146,8 +144,8 @@ struct ScaledEpilogue
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
,
{},
{}
};
return
ArgumentType
{
a_args
,
evt0_args
};
return
ArgumentType
{
a_args
,
evt0_args
,
{}
};
}
}
};
};
...
@@ -160,11 +158,11 @@ struct ScaledEpilogue
...
@@ -160,11 +158,11 @@ struct ScaledEpilogue
* The bias tensor must be per-output channel.
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
TileShape
>
struct
ScaledEpilogueBias
struct
ScaledEpilogueBias
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
{
private:
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
...
@@ -193,8 +191,8 @@ struct ScaledEpilogueBias
...
@@ -193,8 +191,8 @@ struct ScaledEpilogueBias
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
,
{},
{}
};
return
ArgumentType
{
a_args
,
evt0_args
,
bias_args
};
return
ArgumentType
{
a_args
,
evt0_args
,
bias_args
,
{}
};
}
}
};
};
...
@@ -203,11 +201,11 @@ struct ScaledEpilogueBias
...
@@ -203,11 +201,11 @@ struct ScaledEpilogueBias
* bias is a column vector instead of a row vector. Useful e.g. if we are
* bias is a column vector instead of a row vector. Useful e.g. if we are
* computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
* computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
*/
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
TileShape
>
struct
ScaledEpilogueColumnBias
struct
ScaledEpilogueColumnBias
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
{
private:
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
...
@@ -236,8 +234,8 @@ struct ScaledEpilogueColumnBias
...
@@ -236,8 +234,8 @@ struct ScaledEpilogueColumnBias
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
,
{},
{}
};
return
ArgumentType
{
a_args
,
evt0_args
,
bias_args
};
return
ArgumentType
{
a_args
,
evt0_args
,
bias_args
,
{}
};
}
}
};
};
...
@@ -249,11 +247,11 @@ struct ScaledEpilogueColumnBias
...
@@ -249,11 +247,11 @@ struct ScaledEpilogueColumnBias
*
*
* This epilogue also supports bias, which remains per-channel.
* This epilogue also supports bias, which remains per-channel.
*/
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
TileShape
>
struct
ScaledEpilogueBiasAzp
struct
ScaledEpilogueBiasAzp
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
{
private:
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
...
@@ -297,9 +295,10 @@ struct ScaledEpilogueBiasAzp
...
@@ -297,9 +295,10 @@ struct ScaledEpilogueBiasAzp
auto
azp_adj_args
=
auto
azp_adj_args
=
SUPER
::
template
args_from_tensor
<
AzpWithAdj
,
int32_t
>(
azp_adj
);
SUPER
::
template
args_from_tensor
<
AzpWithAdj
,
int32_t
>(
azp_adj
);
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{{},
azp_adj_args
};
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{{},
azp_adj_args
,
{}};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
b_args
,
evt_azp_args
};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
};
b_args
,
evt_azp_args
,
{}};
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
,
{}};
}
}
};
};
...
@@ -313,11 +312,11 @@ struct ScaledEpilogueBiasAzp
...
@@ -313,11 +312,11 @@ struct ScaledEpilogueBiasAzp
*
*
* This epilogue also supports bias, which remains per-channel.
* This epilogue also supports bias, which remains per-channel.
*/
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
TileShape
>
struct
ScaledEpilogueBiasAzpToken
struct
ScaledEpilogueBiasAzpToken
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
{
private:
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
...
@@ -374,10 +373,11 @@ struct ScaledEpilogueBiasAzpToken
...
@@ -374,10 +373,11 @@ struct ScaledEpilogueBiasAzpToken
auto
azp_adj_args
=
auto
azp_adj_args
=
SUPER
::
template
args_from_tensor
<
AzpAdj
,
int32_t
>(
azp_adj
);
SUPER
::
template
args_from_tensor
<
AzpAdj
,
int32_t
>(
azp_adj
);
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{
azp_args
,
azp_adj_args
};
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{
azp_args
,
azp_adj_args
,
{}};
typename
EVTComputeAcc
::
Arguments
evt_acc_args
{{},
evt_azp_args
};
typename
EVTComputeAcc
::
Arguments
evt_acc_args
{{},
evt_azp_args
,
{}};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
b_args
,
evt_acc_args
};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
};
b_args
,
evt_acc_args
,
{}};
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
,
{}};
}
}
};
};
...
...
csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
View file @
469e903b
...
@@ -402,7 +402,7 @@ struct CollectiveMma<
...
@@ -402,7 +402,7 @@ struct CollectiveMma<
// TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128
// TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128
TiledCopy
scale_copy_a
=
make_tiled_copy
(
SmemBlockScalingCopyAtomA
{},
TiledCopy
scale_copy_a
=
make_tiled_copy
(
SmemBlockScalingCopyAtomA
{},
Layout
<
Shape
<
_32
,
_1
>>
{},
Layout
<
Shape
<
_4
,
_1
>>
{});
// (1,1,1)
Layout
<
Shape
<
_32
>>
{},
Layout
<
Shape
<
_1
>>
{});
// (1,1,1)
TiledCopy
scale_copy_b
=
make_tiled_copy
(
SmemBlockScalingCopyAtomB
{},
TiledCopy
scale_copy_b
=
make_tiled_copy
(
SmemBlockScalingCopyAtomB
{},
Layout
<
Shape
<
_1
>>
{},
Layout
<
Shape
<
_1
>>
{});
// (1,1,1)
Layout
<
Shape
<
_1
>>
{},
Layout
<
Shape
<
_1
>>
{});
// (1,1,1)
ThrCopy
thr_scale_copy_a
=
scale_copy_a
.
get_slice
(
threadIdx
.
x
);
ThrCopy
thr_scale_copy_a
=
scale_copy_a
.
get_slice
(
threadIdx
.
x
);
...
...
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
enum
import
enum
from
typing
import
Dict
,
Union
from
typing
import
Union
from
cutlass_library
import
*
from
cutlass_library
import
*
...
@@ -21,7 +21,7 @@ class MixedInputKernelScheduleType(enum.Enum):
...
@@ -21,7 +21,7 @@ class MixedInputKernelScheduleType(enum.Enum):
TmaWarpSpecializedCooperative
=
enum_auto
()
TmaWarpSpecializedCooperative
=
enum_auto
()
VLLMDataTypeNames
:
D
ict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
VLLMDataTypeNames
:
d
ict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
**
DataTypeNames
,
# type: ignore
**
DataTypeNames
,
# type: ignore
**
{
**
{
VLLMDataType
.
u4b8
:
"u4b8"
,
VLLMDataType
.
u4b8
:
"u4b8"
,
...
@@ -29,7 +29,7 @@ VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = {
...
@@ -29,7 +29,7 @@ VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = {
}
}
}
}
VLLMDataTypeTag
:
D
ict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
VLLMDataTypeTag
:
d
ict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
**
DataTypeTag
,
# type: ignore
**
DataTypeTag
,
# type: ignore
**
{
**
{
VLLMDataType
.
u4b8
:
"cutlass::vllm_uint4b8_t"
,
VLLMDataType
.
u4b8
:
"cutlass::vllm_uint4b8_t"
,
...
@@ -37,7 +37,7 @@ VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
...
@@ -37,7 +37,7 @@ VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
}
}
}
}
VLLMDataTypeSize
:
D
ict
[
Union
[
VLLMDataType
,
DataType
],
int
]
=
{
VLLMDataTypeSize
:
d
ict
[
Union
[
VLLMDataType
,
DataType
],
int
]
=
{
**
DataTypeSize
,
# type: ignore
**
DataTypeSize
,
# type: ignore
**
{
**
{
VLLMDataType
.
u4b8
:
4
,
VLLMDataType
.
u4b8
:
4
,
...
@@ -45,7 +45,7 @@ VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = {
...
@@ -45,7 +45,7 @@ VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = {
}
}
}
}
VLLMDataTypeVLLMScalarTypeTag
:
D
ict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
VLLMDataTypeVLLMScalarTypeTag
:
d
ict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
VLLMDataType
.
u4b8
:
"vllm::kU4B8"
,
VLLMDataType
.
u4b8
:
"vllm::kU4B8"
,
VLLMDataType
.
u8b128
:
"vllm::kU8B128"
,
VLLMDataType
.
u8b128
:
"vllm::kU8B128"
,
DataType
.
u4
:
"vllm::kU4"
,
DataType
.
u4
:
"vllm::kU4"
,
...
@@ -56,7 +56,7 @@ VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
...
@@ -56,7 +56,7 @@ VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
DataType
.
bf16
:
"vllm::kBfloat16"
,
DataType
.
bf16
:
"vllm::kBfloat16"
,
}
}
VLLMDataTypeTorchDataTypeTag
:
D
ict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
VLLMDataTypeTorchDataTypeTag
:
d
ict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
DataType
.
u8
:
"at::ScalarType::Byte"
,
DataType
.
u8
:
"at::ScalarType::Byte"
,
DataType
.
s8
:
"at::ScalarType::Char"
,
DataType
.
s8
:
"at::ScalarType::Char"
,
DataType
.
e4m3
:
"at::ScalarType::Float8_e4m3fn"
,
DataType
.
e4m3
:
"at::ScalarType::Float8_e4m3fn"
,
...
@@ -66,7 +66,7 @@ VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
...
@@ -66,7 +66,7 @@ VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
DataType
.
f32
:
"at::ScalarType::Float"
,
DataType
.
f32
:
"at::ScalarType::Float"
,
}
}
VLLMKernelScheduleTag
:
D
ict
[
Union
[
VLLMKernelScheduleTag
:
d
ict
[
Union
[
MixedInputKernelScheduleType
,
KernelScheduleType
],
str
]
=
{
MixedInputKernelScheduleType
,
KernelScheduleType
],
str
]
=
{
**
KernelScheduleTag
,
# type: ignore
**
KernelScheduleTag
,
# type: ignore
**
{
**
{
...
...
csrc/dispatch_utils.h
View file @
469e903b
...
@@ -6,6 +6,11 @@
...
@@ -6,6 +6,11 @@
#include <torch/all.h>
#include <torch/all.h>
// Need a special dispatch case macro since we will nest the FP8 dispatch.
// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'.
#define AT_DISPATCH_FP8_CASE(enum_type, ...) \
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
...
@@ -14,17 +19,32 @@
...
@@ -14,17 +19,32 @@
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
// TODO(luka/varun): use FP8_TYPE macro after refactoring
// ROCm devices might use either fn or fnuz, so set up dispatch table for both.
#ifndef USE_ROCM
// A host-based check at runtime will create a preferred FP8 type for ROCm
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
// such that the correct kernel is dispatched.
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
#ifdef USE_ROCM
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
#else
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#else
#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#endif
#endif
// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
// See AT_DISPATCH_FP8_CASE above.
#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
...
...
csrc/layernorm_quant_kernels.cu
View file @
469e903b
...
@@ -21,9 +21,9 @@
...
@@ -21,9 +21,9 @@
namespace
vllm
{
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
// TODO(woosuk): Further optimize this kernel.
template
<
typename
scalar_t
>
template
<
typename
scalar_t
,
typename
fp8_type
>
__global__
void
rms_norm_static_fp8_quant_kernel
(
__global__
void
rms_norm_static_fp8_quant_kernel
(
FP8_TYPE
*
__restrict__
out
,
// [..., hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
*
__restrict__
scale
,
// [1]
const
float
*
__restrict__
scale
,
// [1]
...
@@ -52,7 +52,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
...
@@ -52,7 +52,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
scaled_fp8_conversion
<
true
>
(
out_norm
,
scale_inv
);
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
}
}
}
}
...
@@ -60,10 +60,10 @@ __global__ void rms_norm_static_fp8_quant_kernel(
...
@@ -60,10 +60,10 @@ __global__ void rms_norm_static_fp8_quant_kernel(
Additional optimizations we can make in this case are
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
packed and vectorized operations, which help with the
memory latency bottleneck. */
memory latency bottleneck. */
template
<
typename
scalar_t
,
int
width
>
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_static_fp8_quant_kernel
(
fused_add_rms_norm_static_fp8_quant_kernel
(
FP8_TYPE
*
__restrict__
out
,
// [..., hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
...
@@ -114,7 +114,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
...
@@ -114,7 +114,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
out
[
id
*
width
+
i
]
=
out
[
id
*
width
+
i
]
=
scaled_fp8_conversion
<
true
>
(
float
(
temp
.
data
[
i
]),
scale_inv
);
scaled_fp8_conversion
<
true
,
fp8_type
>
(
float
(
temp
.
data
[
i
]),
scale_inv
);
}
}
}
}
}
}
...
@@ -122,10 +122,10 @@ fused_add_rms_norm_static_fp8_quant_kernel(
...
@@ -122,10 +122,10 @@ fused_add_rms_norm_static_fp8_quant_kernel(
/* Generic fused_add_rms_norm_kernel
/* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
The width field is not used here but necessary for other specializations.
*/
*/
template
<
typename
scalar_t
,
int
width
>
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_static_fp8_quant_kernel
(
fused_add_rms_norm_static_fp8_quant_kernel
(
FP8_TYPE
*
__restrict__
out
,
// [..., hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
...
@@ -158,7 +158,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
...
@@ -158,7 +158,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
float
x
=
(
float
)
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
x
=
(
float
)
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
scaled_fp8_conversion
<
true
>
(
out_norm
,
scale_inv
);
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
}
}
}
}
...
@@ -176,25 +176,33 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
...
@@ -176,25 +176,33 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
VLLM_DISPATCH_FLOATING_TYPES
(
vllm
::
rms_norm_static_fp8_quant_kernel
<
scalar_t
>
input
.
scalar_type
(),
"rms_norm_kernel_scalar_type"
,
[
&
]
{
<<<
grid
,
block
,
0
,
stream
>>>
(
VLLM_DISPATCH_FP8_TYPES
(
out
.
data_ptr
<
FP8_TYPE
>
(),
input
.
data_ptr
<
scalar_t
>
(),
out
.
scalar_type
(),
"rms_norm_kernel_fp8_type"
,
[
&
]
{
weight
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
epsilon
,
vllm
::
rms_norm_static_fp8_quant_kernel
<
scalar_t
,
fp8_t
>
num_tokens
,
hidden_size
);
<<<
grid
,
block
,
0
,
stream
>>>
(
});
out
.
data_ptr
<
fp8_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
weight
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
});
}
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, width> \
VLLM_DISPATCH_FP8_TYPES( \
<<<grid, block, 0, stream>>>( \
out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(), \
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, \
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
width, fp8_t> \
scale.data_ptr<float>(), epsilon, num_tokens, hidden_size); \
<<<grid, block, 0, stream>>>( \
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), \
epsilon, num_tokens, hidden_size); \
}); \
});
});
void
fused_add_rms_norm_static_fp8_quant
(
void
fused_add_rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size],
torch
::
Tensor
&
out
,
// [..., hidden_size],
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
input
,
// [..., hidden_size]
...
...
csrc/moe/moe_ops.h
View file @
469e903b
...
@@ -24,3 +24,14 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
...
@@ -24,3 +24,14 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
);
torch
::
Tensor
num_tokens_post_pad
);
#ifndef USE_ROCM
torch
::
Tensor
moe_wna16_gemm
(
torch
::
Tensor
input
,
torch
::
Tensor
output
,
torch
::
Tensor
b_qweight
,
torch
::
Tensor
b_scales
,
std
::
optional
<
torch
::
Tensor
>
b_qzeros
,
std
::
optional
<
torch
::
Tensor
>
topk_weights
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
expert_ids
,
torch
::
Tensor
num_tokens_post_pad
,
int64_t
top_k
,
int64_t
BLOCK_SIZE_M
,
int64_t
BLOCK_SIZE_N
,
int64_t
BLOCK_SIZE_K
,
int64_t
bit
);
#endif
\ No newline at end of file
Prev
1
2
3
4
5
6
7
8
9
…
27
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment