Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
469e903b
Commit
469e903b
authored
Mar 28, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.2' into v0.8.2-dev
parents
389ebcf7
25f560a6
Changes
535
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
896 additions
and
198 deletions
+896
-198
csrc/core/math.hpp
csrc/core/math.hpp
+0
-5
csrc/cpu/attention.cpp
csrc/cpu/attention.cpp
+2
-2
csrc/cpu/cache.cpp
csrc/cpu/cache.cpp
+21
-17
csrc/cpu/cpu_types.hpp
csrc/cpu/cpu_types.hpp
+3
-0
csrc/cpu/cpu_types_arm.hpp
csrc/cpu/cpu_types_arm.hpp
+4
-0
csrc/cpu/cpu_types_vxe.hpp
csrc/cpu/cpu_types_vxe.hpp
+480
-0
csrc/cpu/cpu_types_x86.hpp
csrc/cpu/cpu_types_x86.hpp
+9
-0
csrc/cpu/pos_encoding.cpp
csrc/cpu/pos_encoding.cpp
+1
-1
csrc/cpu/quant.cpp
csrc/cpu/quant.cpp
+1
-1
csrc/cuda_utils.h
csrc/cuda_utils.h
+18
-4
csrc/custom_all_reduce.cu
csrc/custom_all_reduce.cu
+41
-0
csrc/custom_all_reduce.cuh
csrc/custom_all_reduce.cuh
+140
-50
csrc/custom_all_reduce_test.cu
csrc/custom_all_reduce_test.cu
+46
-29
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
+14
-12
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
+38
-38
csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
...mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
+1
-1
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
+7
-7
csrc/dispatch_utils.h
csrc/dispatch_utils.h
+26
-6
csrc/layernorm_quant_kernels.cu
csrc/layernorm_quant_kernels.cu
+33
-25
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+11
-0
No files found.
Too many changes to show.
To preserve performance only
535 of 535+
files are displayed.
Plain diff
Email patch
csrc/core/math.hpp
View file @
469e903b
...
...
@@ -7,8 +7,3 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) {
if
(
num
<=
1
)
return
num
;
return
1
<<
(
CHAR_BIT
*
sizeof
(
num
)
-
__builtin_clz
(
num
-
1
));
}
template
<
typename
T
>
inline
constexpr
std
::
enable_if_t
<
std
::
is_integral_v
<
T
>
,
T
>
ceil_div
(
T
a
,
T
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
\ No newline at end of file
csrc/cpu/attention.cpp
View file @
469e903b
...
...
@@ -24,8 +24,8 @@ struct KernelVecType<float> {
template
<
>
struct
KernelVecType
<
c10
::
Half
>
{
#ifdef
__powerpc64__
// Power architecture-specific vector types
#if
def
ined(
__powerpc64__
) || defined(__s390x__)
// Power
and s390x
architecture-specific vector types
using
q_load_vec_type
=
vec_op
::
FP32Vec8
;
using
k_load_vec_type
=
vec_op
::
FP32Vec16
;
using
v_load_vec_type
=
vec_op
::
FP32Vec16
;
...
...
csrc/cpu/cache.cpp
View file @
469e903b
...
...
@@ -3,6 +3,12 @@
#include "cpu_types.hpp"
#if defined(__x86_64__)
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2
#else
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES
#endif
namespace
{
template
<
typename
scalar_t
>
void
copy_blocks_cpu_impl
(
std
::
vector
<
torch
::
Tensor
>
const
&
key_caches
,
...
...
@@ -95,13 +101,12 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
}
const
int
element_num_per_block
=
key_caches
[
0
][
0
].
numel
();
VLLM_DISPATCH_FLOATING_TYPES
(
key_caches
[
0
].
scalar_type
(),
"copy_blocks_cpu_impl"
,
[
&
]
{
CPU_KERNEL_GUARD_IN
(
copy_blocks_cpu_impl
)
copy_blocks_cpu_impl
<
scalar_t
>
(
key_caches
,
value_caches
,
block_mapping
,
element_num_per_block
,
num_layers
);
CPU_KERNEL_GUARD_OUT
(
copy_blocks_cpu_impl
)
});
DISPATCH_MACRO
(
key_caches
[
0
].
scalar_type
(),
"copy_blocks_cpu_impl"
,
[
&
]
{
CPU_KERNEL_GUARD_IN
(
copy_blocks_cpu_impl
)
copy_blocks_cpu_impl
<
scalar_t
>
(
key_caches
,
value_caches
,
block_mapping
,
element_num_per_block
,
num_layers
);
CPU_KERNEL_GUARD_OUT
(
copy_blocks_cpu_impl
)
});
}
void
reshape_and_cache
(
torch
::
Tensor
&
key
,
torch
::
Tensor
&
value
,
...
...
@@ -118,16 +123,15 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
int
key_stride
=
key
.
stride
(
0
);
int
value_stride
=
value
.
stride
(
0
);
VLLM_DISPATCH_FLOATING_TYPES
(
key
.
scalar_type
(),
"reshape_and_cache_cpu_impl"
,
[
&
]
{
CPU_KERNEL_GUARD_IN
(
reshape_and_cache_cpu_impl
)
reshape_and_cache_cpu_impl
<
scalar_t
>
(
key
.
data_ptr
<
scalar_t
>
(),
value
.
data_ptr
<
scalar_t
>
(),
key_cache
.
data_ptr
<
scalar_t
>
(),
value_cache
.
data_ptr
<
scalar_t
>
(),
slot_mapping
.
data_ptr
<
int64_t
>
(),
num_tokens
,
key_stride
,
value_stride
,
num_heads
,
head_size
,
block_size
,
x
);
CPU_KERNEL_GUARD_OUT
(
reshape_and_cache_cpu_impl
)
});
DISPATCH_MACRO
(
key
.
scalar_type
(),
"reshape_and_cache_cpu_impl"
,
[
&
]
{
CPU_KERNEL_GUARD_IN
(
reshape_and_cache_cpu_impl
)
reshape_and_cache_cpu_impl
<
scalar_t
>
(
key
.
data_ptr
<
scalar_t
>
(),
value
.
data_ptr
<
scalar_t
>
(),
key_cache
.
data_ptr
<
scalar_t
>
(),
value_cache
.
data_ptr
<
scalar_t
>
(),
slot_mapping
.
data_ptr
<
int64_t
>
(),
num_tokens
,
key_stride
,
value_stride
,
num_heads
,
head_size
,
block_size
,
x
);
CPU_KERNEL_GUARD_OUT
(
reshape_and_cache_cpu_impl
)
});
}
void
swap_blocks
(
torch
::
Tensor
&
src
,
torch
::
Tensor
&
dst
,
...
...
csrc/cpu/cpu_types.hpp
View file @
469e903b
...
...
@@ -7,6 +7,9 @@
#elif defined(__POWER9_VECTOR__)
// ppc implementation
#include "cpu_types_vsx.hpp"
#elif defined(__s390x__)
// s390 implementation
#include "cpu_types_vxe.hpp"
#elif defined(__aarch64__)
// arm implementation
#include "cpu_types_arm.hpp"
...
...
csrc/cpu/cpu_types_arm.hpp
View file @
469e903b
...
...
@@ -2,6 +2,10 @@
#include <torch/all.h>
#include <cmath>
#if defined(__APPLE__)
#include "omp.h"
#endif
namespace
vec_op
{
#ifdef ARM_BF16_SUPPORT
...
...
csrc/cpu/cpu_types_vxe.hpp
0 → 100644
View file @
469e903b
#ifndef CPU_TYPES_VXE_HPP
#define CPU_TYPES_VXE_HPP
#include <vecintrin.h>
#include <cmath>
#include <torch/all.h>
namespace
vec_op
{
// Element-wise helpers that expand to the plain C++ operators applied to
// their (parenthesized) arguments.
#define vec_neg(a) (-(a))
#define vec_add(a, b) ((a) + (b))
#define vec_sub(a, b) ((a) - (b))
#define vec_mul(a, b) ((a) * (b))
#define vec_div(a, b) ((a) / (b))
#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebraic
#define vec_sl(a, b) ((a) << (b)) // Vector Shift Left
// FIXME: FP16 is not fully supported in Torch-CPU
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) \
std::cout << #NAME << " exit." << std::endl;
#endif
#define FORCE_INLINE __attribute__((always_inline)) inline
namespace {
// Calls f once for every value in the index pack. Each value is passed as a
// std::integral_constant, so the callee may use it as a compile-time constant
// or let it convert to the plain integer type.
template <typename IndexT, IndexT... seq, typename Callable>
constexpr void unroll_loop_item(std::integer_sequence<IndexT, seq...>,
                                Callable&& f) {
  (f(std::integral_constant<IndexT, seq>{}), ...);
}
}  // namespace

// Compile-time unrolled loop: invokes f(0), f(1), ..., f(count - 1) in order.
// Only participates in overload resolution when F is callable with a T.
template <typename T, T count, typename F,
          typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F&& f) {
  using Indices = std::make_integer_sequence<T, count>;
  unroll_loop_item(Indices{}, std::forward<F>(f));
}
// Common base for the vector wrappers below; each wrapper derives from
// Vec<itself> and supplies a VEC_ELEM_NUM constant.
template <typename T>
struct Vec {
  // Returns the derived type's VEC_ELEM_NUM.
  static constexpr int get_elem_num() { return T::VEC_ELEM_NUM; }
};
// Small aggregates that bundle two or four vector registers so a wider
// logical vector can be passed around as a single value.

// 2 x (vector of signed short).
struct ss16x8x2_t {
  __vector signed short val[2];
};

// 4 x (vector of signed short).
struct ss16x8x4_t {
  __vector signed short val[4];
};

// 2 x (vector of float).
struct f32x4x2_t {
  __vector float val[2];
};

// 4 x (vector of float).
struct f32x4x4_t {
  __vector float val[4];
};
struct
FP32Vec8
;
struct
FP32Vec16
;
// Eight 16-bit lanes kept in a single vector register (used for bfloat16
// data; the lanes are stored as raw signed shorts).
struct BF16Vec8 : public Vec<BF16Vec8> {
  constexpr static int VEC_ELEM_NUM = 8;

  __vector signed short reg;

  // Loads one full vector register from ptr.
  explicit BF16Vec8(const void* ptr)
      : reg(*reinterpret_cast<const __vector signed short*>(ptr)) {}

  // Conversion from eight floats; only declared here.
  explicit BF16Vec8(const FP32Vec8&);

  // Stores the register to ptr.
  void save(void* ptr) const {
    auto* dst = reinterpret_cast<__vector signed short*>(ptr);
    *dst = reg;
  }
};
// Sixteen 16-bit lanes held as two vector registers.
struct BF16Vec16 : public Vec<BF16Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;

  ss16x8x2_t reg;

  // Loads 256 bits from ptr as two 128-bit halves (byte offsets 0 and 16).
  explicit BF16Vec16(const void* ptr) {
    signed short* src =
        const_cast<signed short*>(static_cast<const signed short*>(ptr));
    reg.val[0] = (__vector signed short)vec_xl(0, src);
    reg.val[1] = (__vector signed short)vec_xl(16, src);
  }

  // Conversion from sixteen floats; only declared here.
  explicit BF16Vec16(const FP32Vec16&);

  // Stores 256 bits to ptr as two 128-bit halves (byte offsets 0 and 16).
  void save(void* ptr) const {
    signed short* dst = static_cast<signed short*>(ptr);
    vec_xst(reg.val[0], 0, dst);
    vec_xst(reg.val[1], 16, dst);
  }
};
const
static
__vector
signed
short
zero
=
vec_splats
((
signed
short
)
0
);
// Thirty-two 16-bit lanes held as four vector registers.
struct BF16Vec32 : public Vec<BF16Vec32> {
  constexpr static int VEC_ELEM_NUM = 32;

  ss16x8x4_t reg;

  // Copies all four registers from the memory at ptr.
  explicit BF16Vec32(const void* ptr)
      : reg(*reinterpret_cast<const ss16x8x4_t*>(ptr)) {}

  // Wraps an already-assembled register quadruple.
  explicit BF16Vec32(ss16x8x4_t data) : reg(data) {}

  // Replicates one 8-lane vector into all four registers.
  explicit BF16Vec32(const BF16Vec8& vec8_data)
      : reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}

  // Writes all four registers to the memory at ptr.
  void save(void* ptr) const {
    auto* dst = reinterpret_cast<ss16x8x4_t*>(ptr);
    *dst = reg;
  }
};
// Four floats in a single vector register.
struct FP32Vec4 : public Vec<FP32Vec4> {
  constexpr static int VEC_ELEM_NUM = 4;

  // Gives per-lane scalar access to the register contents.
  union AliasReg {
    __vector float reg;
    float values[VEC_ELEM_NUM];
  };

  __vector float reg;

  // Broadcasts v to every lane.
  explicit FP32Vec4(float v) : reg(vec_splats(v)) {}

  // All lanes zero.
  explicit FP32Vec4() : FP32Vec4(0.0f) {}

  // Loads four floats starting at ptr.
  explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {}

  // Wraps an existing register value.
  explicit FP32Vec4(__vector float data) : reg(data) {}

  // Copies another FP32Vec4.
  explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
};
// Eight floats held as two vector registers (val[0] = lanes 0-3,
// val[1] = lanes 4-7).
struct FP32Vec8 : public Vec<FP32Vec8> {
  constexpr static int VEC_ELEM_NUM = 8;

  // Gives per-lane scalar access to the register pair.
  union AliasReg {
    f32x4x2_t reg;
    float values[VEC_ELEM_NUM];
  };

  f32x4x2_t reg;

  // Broadcasts v to every lane.
  explicit FP32Vec8(float v) {
    reg.val[0] = vec_splats(v);
    reg.val[1] = vec_splats(v);
  }

  // All lanes zero.
  explicit FP32Vec8() {
    reg.val[0] = vec_splats(0.0f);
    reg.val[1] = vec_splats(0.0f);
  }

  // Loads eight floats starting at ptr (byte offsets 0 and 16).
  explicit FP32Vec8(const float* ptr) {
    reg.val[0] = vec_xl(0, ptr);
    reg.val[1] = vec_xl(16, ptr);
  }

  // Wraps an existing register pair.
  explicit FP32Vec8(f32x4x2_t data) : reg(data) {}

  // Copies another FP32Vec8.
  explicit FP32Vec8(const FP32Vec8& data) {
    reg.val[0] = data.reg.val[0];
    reg.val[1] = data.reg.val[1];
  }

  // Widens eight bfloat16 lanes to float. A bfloat16 is the upper 16 bits of
  // the float, so on big-endian s390x the bf16 halfword must come first in
  // each merged 32-bit lane, followed by a zero halfword. (Merging
  // `zero` first would place the payload in the low mantissa bits.) This
  // mirrors BF16Vec8(const FP32Vec8&), which keeps the upper 16 bits.
  explicit FP32Vec8(const BF16Vec8& v) {
    reg.val[0] = (__vector float)vec_mergeh(v.reg, zero);
    reg.val[1] = (__vector float)vec_mergel(v.reg, zero);
  }

  // Horizontal sum of all eight lanes, accumulated in lane order.
  float reduce_sum() const {
    AliasReg ar;
    ar.reg = reg;
    float result = 0;
    unroll_loop<int, VEC_ELEM_NUM>(
        [&result, &ar](int i) { result += ar.values[i]; });
    return result;
  }

  // Lane-wise std::exp.
  FP32Vec8 exp() const {
    // TODO: Vectorize this
    return map_lanes([](float x) { return std::exp(x); });
  }

  // Lane-wise std::tanh.
  FP32Vec8 tanh() const {
    // TODO: Vectorize this
    return map_lanes([](float x) { return std::tanh(x); });
  }

  // Lane-wise std::erf.
  FP32Vec8 er() const {
    // TODO: Vectorize this
    return map_lanes([](float x) { return std::erf(x); });
  }

  // Lane-wise arithmetic.
  FP32Vec8 operator*(const FP32Vec8& b) const {
    return FP32Vec8(f32x4x2_t({vec_mul(reg.val[0], b.reg.val[0]),
                               vec_mul(reg.val[1], b.reg.val[1])}));
  }

  FP32Vec8 operator+(const FP32Vec8& b) const {
    return FP32Vec8(f32x4x2_t({vec_add(reg.val[0], b.reg.val[0]),
                               vec_add(reg.val[1], b.reg.val[1])}));
  }

  FP32Vec8 operator-(const FP32Vec8& b) const {
    return FP32Vec8(f32x4x2_t({vec_sub(reg.val[0], b.reg.val[0]),
                               vec_sub(reg.val[1], b.reg.val[1])}));
  }

  FP32Vec8 operator/(const FP32Vec8& b) const {
    return FP32Vec8(f32x4x2_t({vec_div(reg.val[0], b.reg.val[0]),
                               vec_div(reg.val[1], b.reg.val[1])}));
  }

  // Stores eight floats starting at ptr (byte offsets 0 and 16).
  void save(float* ptr) const {
    vec_xst(reg.val[0], 0, ptr);
    vec_xst(reg.val[1], 16, ptr);
  }

 private:
  // Applies a scalar float -> float function to every lane. Scalar fallback
  // shared by exp/tanh/er until a vectorized implementation exists.
  template <typename F>
  FP32Vec8 map_lanes(F&& op) const {
    AliasReg in;
    in.reg = reg;
    f32x4x2_t out;
    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
      out.val[i / 4][i % 4] = op(in.values[i]);
    }
    return FP32Vec8(out);
  }
};
// Sixteen floats held as four vector registers (val[k] = lanes 4k..4k+3).
struct FP32Vec16 : public Vec<FP32Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;

  // Gives per-lane scalar access to the register quadruple.
  union AliasReg {
    f32x4x4_t reg;
    float values[VEC_ELEM_NUM];
  };

  f32x4x4_t reg;

  // Broadcasts v to every lane.
  explicit FP32Vec16(float v) {
    reg.val[0] = vec_splats(v);
    reg.val[1] = vec_splats(v);
    reg.val[2] = vec_splats(v);
    reg.val[3] = vec_splats(v);
  }

  // All lanes zero.
  explicit FP32Vec16() {
    reg.val[0] = vec_splats(0.0f);
    reg.val[1] = vec_splats(0.0f);
    reg.val[2] = vec_splats(0.0f);
    reg.val[3] = vec_splats(0.0f);
  }

  // Loads sixteen floats starting at ptr (byte offsets 0, 16, 32, 48).
  explicit FP32Vec16(const float* ptr) {
    reg.val[0] = vec_xl(0, ptr);
    reg.val[1] = vec_xl(16, ptr);
    reg.val[2] = vec_xl(32, ptr);
    reg.val[3] = vec_xl(48, ptr);
  }

  // Wraps an existing register quadruple.
  explicit FP32Vec16(f32x4x4_t data) : reg(data) {}

  // Copies another FP32Vec16.
  explicit FP32Vec16(const FP32Vec16& data) {
    reg.val[0] = data.reg.val[0];
    reg.val[1] = data.reg.val[1];
    reg.val[2] = data.reg.val[2];
    reg.val[3] = data.reg.val[3];
  }

  // Repeats a 4-lane vector four times.
  explicit FP32Vec16(const FP32Vec4& data) {
    reg.val[0] = data.reg;
    reg.val[1] = data.reg;
    reg.val[2] = data.reg;
    reg.val[3] = data.reg;
  }

  // Repeats an 8-lane vector twice.
  explicit FP32Vec16(const FP32Vec8& data) {
    reg.val[0] = data.reg.val[0];
    reg.val[1] = data.reg.val[1];
    reg.val[2] = data.reg.val[0];
    reg.val[3] = data.reg.val[1];
  }

  // Widens sixteen bfloat16 lanes to float. A bfloat16 is the upper 16 bits
  // of the float, so on big-endian s390x the bf16 halfword must come first
  // in each merged 32-bit lane, followed by a zero halfword. (Merging `zero`
  // first would place the payload in the low mantissa bits.) This mirrors
  // BF16Vec16(const FP32Vec16&), which keeps the upper 16 bits.
  explicit FP32Vec16(const BF16Vec16& v) {
    reg.val[0] = (__vector float)vec_mergeh(v.reg.val[0], zero);
    reg.val[1] = (__vector float)vec_mergel(v.reg.val[0], zero);
    reg.val[2] = (__vector float)vec_mergeh(v.reg.val[1], zero);
    reg.val[3] = (__vector float)vec_mergel(v.reg.val[1], zero);
  }

  // Widens eight bfloat16 lanes and repeats them twice.
  explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}

  // Lane-wise arithmetic.
  FP32Vec16 operator*(const FP32Vec16& b) const {
    return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
                                vec_mul(reg.val[1], b.reg.val[1]),
                                vec_mul(reg.val[2], b.reg.val[2]),
                                vec_mul(reg.val[3], b.reg.val[3])}));
  }

  FP32Vec16 operator+(const FP32Vec16& b) const {
    return FP32Vec16(f32x4x4_t({vec_add(reg.val[0], b.reg.val[0]),
                                vec_add(reg.val[1], b.reg.val[1]),
                                vec_add(reg.val[2], b.reg.val[2]),
                                vec_add(reg.val[3], b.reg.val[3])}));
  }

  FP32Vec16 operator-(const FP32Vec16& b) const {
    return FP32Vec16(f32x4x4_t({vec_sub(reg.val[0], b.reg.val[0]),
                                vec_sub(reg.val[1], b.reg.val[1]),
                                vec_sub(reg.val[2], b.reg.val[2]),
                                vec_sub(reg.val[3], b.reg.val[3])}));
  }

  FP32Vec16 operator/(const FP32Vec16& b) const {
    return FP32Vec16(f32x4x4_t({vec_div(reg.val[0], b.reg.val[0]),
                                vec_div(reg.val[1], b.reg.val[1]),
                                vec_div(reg.val[2], b.reg.val[2]),
                                vec_div(reg.val[3], b.reg.val[3])}));
  }

  // Horizontal sum of all sixteen lanes, accumulated in lane order.
  float reduce_sum() const {
    AliasReg ar;
    ar.reg = reg;
    float result = 0;
    unroll_loop<int, VEC_ELEM_NUM>(
        [&result, &ar](int i) { result += ar.values[i]; });
    return result;
  }

  // Sums the idx-th group of group_size consecutive lanes, i.e. lanes
  // [idx * group_size, (idx + 1) * group_size). group_size must divide
  // VEC_ELEM_NUM; idx is not range-checked.
  template <int group_size>
  float reduce_sub_sum(int idx) const {
    static_assert(VEC_ELEM_NUM % group_size == 0);
    AliasReg ar;
    ar.reg = reg;
    float result = 0;
    const int start = idx * group_size;
    // Capture the alias union by reference: it is 64 bytes and only read.
    unroll_loop<int, group_size>(
        [&result, &start, &ar](int i) { result += ar.values[start + i]; });
    return result;
  }

  // Stores sixteen floats starting at ptr (byte offsets 0, 16, 32, 48).
  void save(float* ptr) const {
    vec_xst(reg.val[0], 0, ptr);
    vec_xst(reg.val[1], 16, ptr);
    vec_xst(reg.val[2], 32, ptr);
    vec_xst(reg.val[3], 48, ptr);
  }
};
// Trait mapping a scalar element type to its vector wrapper. Types without a
// specialization map to void.
template <typename T>
struct VecType {
  using vec_type = void;
};

// Shorthand for VecType<T>::vec_type.
template <typename T>
using vec_t = typename VecType<T>::vec_type;

// float -> FP32Vec8
template <>
struct VecType<float> {
  using vec_type = FP32Vec8;
};

// c10::BFloat16 -> BF16Vec8
template <>
struct VecType<c10::BFloat16> {
  using vec_type = BF16Vec8;
};
// Writes the float v to *ptr, relying on the implicit conversion from float
// to T.
template <typename T>
void storeFP32(float v, T* ptr) {
  T& slot = *ptr;
  slot = v;
}
// Multiply-accumulate on 16-lane float vectors: acc = acc + a * b.
// Written as a separate multiply and add using FP32Vec16's operators.
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
  FP32Vec16 product(a * b);
  acc = acc + product;
}
// Placeholder 16-bit storage type named c10::BFloat16.
// NOTE(review): this namespace is opened inside the enclosing namespace, so
// code that follows in that scope and names c10::BFloat16 presumably resolves
// to this struct rather than a global c10::BFloat16 — confirm this is
// intended.
namespace
c10
{
struct
BFloat16
{
uint16_t
value
;
// Assume BFloat16 is defined as a struct containing a 16-bit
// value.
};
}
// namespace c10
// Stores a float as bfloat16 by truncation: keeps the upper 16 bits of the
// IEEE-754 single-precision encoding (no rounding).
//
// The upper halfword sits at index 0 of the float's storage on big-endian
// targets (such as s390x) and at index 1 on little-endian targets, so the
// index is chosen from the compiler's byte-order macros. Always reading
// index 1 would store the low mantissa bits on big-endian machines.
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
  // may_alias lets the float's bytes be read through a 16-bit type.
  c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
      reinterpret_cast<c10::BFloat16*>(&v);
#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
    (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
  *ptr = *v_ptr;
#else
  *ptr = *(v_ptr + 1);
#endif
}
#ifndef __VEC_CLASS_FP_NAN
#define __VEC_CLASS_FP_NAN (1 << 6)
#endif
const
static
__vector
unsigned
char
omask
=
{
2
,
3
,
6
,
7
,
10
,
11
,
14
,
15
,
18
,
19
,
22
,
23
,
26
,
27
,
30
,
31
};
const
static
__vector
unsigned
int
bias
=
{
0x00007fff
,
0x00007fff
,
0x00007fff
,
0x00007fff
};
const
static
__vector
unsigned
int
nan
=
{
0x7fc00000
,
0x7fc00000
,
0x7fc00000
,
0x7fc00000
};
const
static
__vector
unsigned
int
sh16
=
{
16
,
16
,
16
,
16
};
const
static
__vector
unsigned
int
one
=
{
1
,
1
,
1
,
1
};
// Narrows eight floats to eight 16-bit lanes.
// Per 4-lane half: lanes classified as NaN are replaced by the `nan`
// constant, every 32-bit lane is shifted right by `sh16`, and vec_perm with
// `omask` then gathers bytes from both halves into one register.
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
  __vector unsigned int lo_bits = (__vector unsigned int)(v.reg.val[0]);
  __vector unsigned int hi_bits = (__vector unsigned int)(v.reg.val[1]);

  int cc;  // receives the condition code; not inspected afterwards
  __vector __bool int lo_is_nan =
      vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
  __vector __bool int hi_is_nan =
      vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc);

  lo_bits = vec_sel(lo_bits, nan, lo_is_nan) >> sh16;
  hi_bits = vec_sel(hi_bits, nan, hi_is_nan) >> sh16;

  reg = (__vector signed short)vec_perm(lo_bits, hi_bits, omask);
}
// Narrows sixteen floats to sixteen 16-bit lanes.
// Per 4-lane quarter: lanes classified as NaN are replaced by the `nan`
// constant and every 32-bit lane is shifted right by `sh16`. Quarters 0/1
// and 2/3 are then each combined by vec_perm with `omask` into one output
// register.
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
  __vector unsigned int bits0 = (__vector unsigned int)(v.reg.val[0]);
  __vector unsigned int bits1 = (__vector unsigned int)(v.reg.val[1]);
  __vector unsigned int bits2 = (__vector unsigned int)(v.reg.val[2]);
  __vector unsigned int bits3 = (__vector unsigned int)(v.reg.val[3]);

  int cc;  // receives the condition code; not inspected afterwards
  __vector __bool int is_nan0 =
      vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
  __vector __bool int is_nan1 =
      vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc);
  __vector __bool int is_nan2 =
      vec_fp_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN, &cc);
  __vector __bool int is_nan3 =
      vec_fp_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN, &cc);

  bits0 = vec_sel(bits0, nan, is_nan0) >> sh16;
  bits1 = vec_sel(bits1, nan, is_nan1) >> sh16;
  bits2 = vec_sel(bits2, nan, is_nan2) >> sh16;
  bits3 = vec_sel(bits3, nan, is_nan3) >> sh16;

  reg.val[0] = (__vector signed short)vec_perm(bits0, bits1, omask);
  reg.val[1] = (__vector signed short)vec_perm(bits2, bits3, omask);
}
// Hints that the cache line containing addr will be read soon.
//
// The previous body only *declared* a function named __dcbt and never called
// anything, so no prefetch was ever issued. __builtin_prefetch is the
// GCC/Clang builtin for this; it never faults, even on an invalid address,
// and is dropped on targets without a prefetch instruction.
inline void prefetch(const void* addr) { __builtin_prefetch(addr); }
};
// namespace vec_op
#endif
\ No newline at end of file
csrc/cpu/cpu_types_x86.hpp
View file @
469e903b
...
...
@@ -16,9 +16,18 @@ namespace vec_op {
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(__VA_ARGS__))
#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
...
...
csrc/cpu/pos_encoding.cpp
View file @
469e903b
...
...
@@ -170,7 +170,7 @@ void rotary_embedding_gptj_impl(
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
)
{
int
num_tokens
=
query
.
numel
()
/
query
.
size
(
-
1
);
int
num_tokens
=
positions
.
numel
(
);
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
num_heads
=
query
.
size
(
-
1
)
/
head_size
;
int
num_kv_heads
=
key
.
size
(
-
1
)
/
head_size
;
...
...
csrc/cpu/quant.cpp
View file @
469e903b
...
...
@@ -25,7 +25,7 @@ struct KernelVecType<c10::BFloat16> {
template
<
>
struct
KernelVecType
<
c10
::
Half
>
{
#ifdef
__powerpc64__
#if
def
ined(
__powerpc64__
) || defined(__s390x__)
// Power architecture-specific vector type
using
load_vec_type
=
vec_op
::
FP32Vec16
;
#else
...
...
csrc/cuda_utils.h
View file @
469e903b
...
...
@@ -2,10 +2,14 @@
#include <stdio.h>
#if defined(__CUDACC__) || defined(_NVHPC_CUDA)
#define HOST_DEVICE_INLINE __forceinline__ __host__ __device__
#define DEVICE_INLINE __forceinline__ __device__
#define HOST_INLINE __forceinline__ __host__
#if defined(__HIPCC__)
#define HOST_DEVICE_INLINE __host__ __device__
#define DEVICE_INLINE __device__
#define HOST_INLINE __host__
#elif defined(__CUDACC__) || defined(_NVHPC_CUDA)
#define HOST_DEVICE_INLINE __host__ __device__ __forceinline__
#define DEVICE_INLINE __device__ __forceinline__
#define HOST_INLINE __host__ __forceinline__
#else
#define HOST_DEVICE_INLINE inline
#define DEVICE_INLINE inline
...
...
@@ -25,3 +29,13 @@
int64_t
get_device_attribute
(
int64_t
attribute
,
int64_t
device_id
);
int64_t
get_max_shared_memory_per_block_device_attribute
(
int64_t
device_id
);
namespace cuda_utils {

// Integer division rounding up, computed as (a + b - 1) / b.
// Enabled for integral T only. NOTE(review): presumably meant for a >= 0 and
// b > 0; a + b - 1 is not guarded against overflow.
template <typename T>
HOST_DEVICE_INLINE constexpr std::enable_if_t<std::is_integral_v<T>, T>
ceil_div(T a, T b) {
  return (a + b - 1) / b;
}

};  // namespace cuda_utils
\ No newline at end of file
csrc/custom_all_reduce.cu
View file @
469e903b
...
...
@@ -142,3 +142,44 @@ void register_graph_buffers(fptr_t _fa,
bytes
.
reserve
(
handles
.
size
());
fa
->
register_graph_buffers
(
bytes
,
offsets
);
}
std
::
tuple
<
fptr_t
,
torch
::
Tensor
>
allocate_shared_buffer_and_handle
(
int64_t
size
)
{
auto
device_index
=
c10
::
cuda
::
current_device
();
at
::
DeviceGuard
device_guard
(
at
::
Device
(
at
::
DeviceType
::
CUDA
,
device_index
));
void
*
buffer
;
cudaStreamCaptureMode
mode
=
cudaStreamCaptureModeRelaxed
;
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
AT_CUDA_CHECK
(
cudaThreadExchangeStreamCaptureMode
(
&
mode
));
#if defined(USE_ROCM)
// data buffers need to be "uncached" for signal on MI200
AT_CUDA_CHECK
(
hipExtMallocWithFlags
((
void
**
)
&
buffer
,
size
,
hipDeviceMallocUncached
));
#else
AT_CUDA_CHECK
(
cudaMalloc
((
void
**
)
&
buffer
,
size
));
#endif
AT_CUDA_CHECK
(
cudaMemsetAsync
(
buffer
,
0
,
size
,
stream
));
AT_CUDA_CHECK
(
cudaStreamSynchronize
(
stream
));
AT_CUDA_CHECK
(
cudaThreadExchangeStreamCaptureMode
(
&
mode
));
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
torch
::
kCPU
);
auto
handle
=
torch
::
empty
({
static_cast
<
int64_t
>
(
sizeof
(
cudaIpcMemHandle_t
))},
options
);
AT_CUDA_CHECK
(
cudaIpcGetMemHandle
((
cudaIpcMemHandle_t
*
)
handle
.
data_ptr
(),
buffer
));
return
std
::
make_tuple
(
reinterpret_cast
<
fptr_t
>
(
buffer
),
handle
);
}
// Opens the IPC memory handle stored in mem_handle's data and returns the
// resulting device pointer as fptr_t. The tensor's bytes are read as a
// cudaIpcMemHandle_t.
fptr_t open_mem_handle(torch::Tensor& mem_handle) {
  void* ipc_ptr;
  const auto* handle =
      static_cast<const cudaIpcMemHandle_t*>(mem_handle.data_ptr());
  AT_CUDA_CHECK(cudaIpcOpenMemHandle(&ipc_ptr, *handle,
                                     cudaIpcMemLazyEnablePeerAccess));
  return reinterpret_cast<fptr_t>(ipc_ptr);
}
// Releases a device buffer identified by its fptr_t value via cudaFree.
void free_shared_buffer(fptr_t buffer) {
  void* device_ptr = reinterpret_cast<void*>(buffer);
  AT_CUDA_CHECK(cudaFree(device_ptr));
}
csrc/custom_all_reduce.cuh
View file @
469e903b
...
...
@@ -5,6 +5,10 @@
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#if defined(USE_ROCM)
typedef
__hip_bfloat16
nv_bfloat16
;
#endif
#include <iostream>
#include <array>
#include <limits>
...
...
@@ -12,6 +16,7 @@
#include <unordered_map>
#include <vector>
namespace
vllm
{
#define CUDACHECK(cmd) \
do { \
cudaError_t e = cmd; \
...
...
@@ -22,24 +27,37 @@
} \
} while (0)
namespace
vllm
{
// Maximal number of blocks in allreduce kernel.
constexpr
int
kMaxBlocks
=
36
;
// Default number of blocks in allreduce kernel.
#ifndef USE_ROCM
const
int
defaultBlockLimit
=
36
;
CUpointer_attribute
rangeStartAddrAttr
=
CUDA_POINTER_ATTRIBUTE_RANGE_START_ADDR
;
#else
const
int
defaultBlockLimit
=
16
;
hipPointer_attribute
rangeStartAddrAttr
=
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR
;
#endif
// Counter may overflow, but it's fine since unsigned int overflow is
// well-defined behavior.
using
FlagType
=
uint32_t
;
// Two sets of peer counters are needed for two syncs. The reason is that
// it's possible for peer GPU block to arrive at the second sync point while
// the current GPU block haven't passed the first sync point. Thus, peer GPU
// may write counter+1 while current GPU is busy waiting for counter. We use
// alternating counter array to avoid this possibility.
struct
Signal
{
alignas
(
128
)
FlagType
self_counter
[
kMaxBlocks
][
8
];
// Two sets of peer counters are needed for two syncs. The reason is that
// it's possible for peer GPU block to arrive at the second sync point while
// the current GPU block haven't passed the first sync point. Thus, peer GPU
// may write counter+1 while current GPU is busy waiting for counter. We use
// alternating counter array to avoid this possibility.
alignas
(
128
)
FlagType
peer_counter
[
2
][
kMaxBlocks
][
8
];
alignas
(
128
)
FlagType
start
[
kMaxBlocks
][
8
];
alignas
(
128
)
FlagType
end
[
kMaxBlocks
][
8
];
alignas
(
128
)
FlagType
_flag
[
kMaxBlocks
];
// incremental flags for each rank
};
struct
__align__
(
16
)
RankData
{
const
void
*
__restrict__
ptrs
[
8
];
const
void
*
ptrs
[
8
];
};
struct
__align__
(
16
)
RankSignals
{
...
...
@@ -134,27 +152,29 @@ DINLINE O downcast(array_t<float, O::size> val) {
}
}
#if !defined(USE_ROCM)
static
DINLINE
void
st_flag_release
(
FlagType
*
flag_addr
,
FlagType
flag
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm
volatile
(
"st.release.sys.global.u32 [%1], %0;"
::
"r"
(
flag
),
"l"
(
flag_addr
));
#else
#else
asm
volatile
(
"membar.sys; st.volatile.global.u32 [%1], %0;"
::
"r"
(
flag
),
"l"
(
flag_addr
));
#endif
#endif
}
static
DINLINE
FlagType
ld_flag_acquire
(
FlagType
*
flag_addr
)
{
FlagType
flag
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm
volatile
(
"ld.acquire.sys.global.u32 %0, [%1];"
:
"=r"
(
flag
)
:
"l"
(
flag_addr
));
#else
#else
asm
volatile
(
"ld.volatile.global.u32 %0, [%1]; membar.gl;"
:
"=r"
(
flag
)
:
"l"
(
flag_addr
));
#endif
#endif
return
flag
;
}
...
...
@@ -170,37 +190,108 @@ static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
return
flag
;
}
// is_start: whether this is the very first synchronization barrier.
// need_fence: whether a memory fence is needed. If true, a release-acquire
// semantic is used to enforce memory access order before and after this
// barrier.
template
<
int
ngpus
,
bool
is_start
,
bool
need_fence
=
false
>
DINLINE
void
multi_gpu_barrier
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
int
rank
)
{
if
constexpr
(
!
is_start
)
__syncthreads
();
static_assert
(
!
(
is_start
&&
need_fence
));
// Start barrier shouldn't need fence.
// This function is meant to be used as the first synchronization in the all
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template
<
int
ngpus
>
DINLINE
void
start_sync
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
int
rank
)
{
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
if
(
threadIdx
.
x
<
ngpus
)
{
// Increment the counter. Technically we only need one counter, but we use
// multiple per block to eliminate the need to share the counter via smem.
auto
val
=
self_sg
->
self_counter
[
blockIdx
.
x
][
threadIdx
.
x
]
+=
1
;
auto
peer_counter_ptr
=
&
sg
.
signals
[
threadIdx
.
x
]
->
start
[
blockIdx
.
x
][
rank
];
auto
self_counter_ptr
=
&
self_sg
->
start
[
blockIdx
.
x
][
threadIdx
.
x
];
// Write the expected counter value to peer and wait for correct value
// from peer.
st_flag_volatile
(
peer_counter_ptr
,
flag
);
while
(
ld_flag_volatile
(
self_counter_ptr
)
!=
flag
);
}
__syncthreads
();
// use one thread to update flag
if
(
threadIdx
.
x
==
0
)
self_sg
->
_flag
[
blockIdx
.
x
]
=
flag
;
}
// This function is meant to be used as the second or the final
// synchronization barrier in the all reduce kernel. If it's the final
// synchronization barrier, we don't need to make any visibility guarantees
// for prior memory accesses.
template
<
int
ngpus
,
bool
final_sync
=
false
>
DINLINE
void
end_sync
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
int
rank
)
{
__syncthreads
();
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
if
(
threadIdx
.
x
<
ngpus
)
{
auto
peer_counter_ptr
=
&
sg
.
signals
[
threadIdx
.
x
]
->
end
[
blockIdx
.
x
][
rank
];
auto
self_counter_ptr
=
&
self_sg
->
end
[
blockIdx
.
x
][
threadIdx
.
x
];
// Write the expected counter value to peer and wait for correct value from
// peer.
auto
peer_counter_ptr
=
&
sg
.
signals
[
threadIdx
.
x
]
->
peer_counter
[
val
%
2
][
blockIdx
.
x
][
rank
];
auto
self_counter_ptr
=
&
self_sg
->
peer_counter
[
val
%
2
][
blockIdx
.
x
][
threadIdx
.
x
];
if
constexpr
(
need_fence
)
{
st_flag_release
(
peer_counter_ptr
,
val
);
while
(
ld_flag_acquire
(
self_counter_ptr
)
!=
val
);
if
constexpr
(
!
final_sync
)
{
st_flag_release
(
peer_counter_ptr
,
flag
);
while
(
ld_flag_acquire
(
self_counter_ptr
)
!=
flag
);
}
else
{
st_flag_volatile
(
peer_counter_ptr
,
val
);
while
(
ld_flag_volatile
(
self_counter_ptr
)
!=
val
);
st_flag_volatile
(
peer_counter_ptr
,
flag
);
while
(
ld_flag_volatile
(
self_counter_ptr
)
!=
flag
);
}
}
if
constexpr
(
is_start
||
need_fence
)
__syncthreads
();
if
constexpr
(
!
final_sync
)
__syncthreads
();
// use one thread to update flag
if
(
threadIdx
.
x
==
0
)
self_sg
->
_flag
[
blockIdx
.
x
]
=
flag
;
}
#else
// ROCm variant of the first cross-rank barrier.
// Each of the first `ngpus` threads of a block writes the block's next flag
// value into one rank's `start` slot for this rank, then spins until its own
// `start` slot for that rank reaches the same value. All accesses use relaxed
// ordering.
template <int ngpus>
DINLINE void start_sync(const RankSignals& sg, Signal* self_sg, int rank) {
  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    // A scoped form (__scoped_atomic_store_n with __MEMORY_SCOPE_SYSTEM and
    // __scoped_atomic_load_n with __MEMORY_SCOPE_DEVICE) is not enabled; the
    // unscoped __atomic builtins are used.
    auto* peer_slot = &sg.signals[threadIdx.x]->start[blockIdx.x][rank];
    auto* self_slot = &self_sg->start[blockIdx.x][threadIdx.x];

    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    __atomic_store_n(peer_slot, flag, __ATOMIC_RELAXED);
    // wait until every rank has published at least `flag`
    while (__atomic_load_n(self_slot, __ATOMIC_RELAXED) < flag);
  }
  __syncthreads();
  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}
// Closing barrier of the all-reduce kernel (variant compiled in the #else
// branch, built on the __atomic_* builtins).
//
// Starts with a block-wide sync, then the first `ngpus` threads each handle
// one peer: thread t publishes the next flag value into
// end[blockIdx.x][rank] of sg.signals[t] and spins until the local
// end[blockIdx.x][t] has reached at least that value.
//
// `final_sync` selects the memory ordering: when true the store/load are
// relaxed and the trailing __syncthreads() is skipped; when false the store
// is a release, the load an acquire, and the block syncs again afterwards.
// Thread 0 finally records the new flag value in self_sg->_flag[blockIdx.x].
//
// Disabled reference variant using scoped atomics:
//   __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
//                           flag,
//                           final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
//                           __MEMORY_SCOPE_SYSTEM);
//   while (__scoped_atomic_load_n(
//              &self_sg->end[blockIdx.x][threadIdx.x],
//              final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
//              __MEMORY_SCOPE_DEVICE) < flag);
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals& sg, Signal* self_sg, int rank) {
  __syncthreads();

  const uint32_t next_flag = self_sg->_flag[blockIdx.x] + 1;

  if (threadIdx.x < ngpus) {
    // Orderings are fixed at compile time by the final_sync template flag.
    constexpr int kStoreOrder =
        final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE;
    constexpr int kLoadOrder =
        final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE;

    // Slot on the peer that belongs to this rank, and the local slot that the
    // peer writes into.
    auto* peer_slot = &sg.signals[threadIdx.x]->end[blockIdx.x][rank];
    auto* own_slot = &self_sg->end[blockIdx.x][threadIdx.x];

    __atomic_store_n(peer_slot, next_flag, kStoreOrder);

    // Busy-wait until the peer has signalled this round.
    while (__atomic_load_n(own_slot, kLoadOrder) < next_flag) {
    }
  }

  if constexpr (!final_sync) {
    __syncthreads();
  }

  // A single thread persists the flag for the next barrier.
  if (threadIdx.x == 0) {
    self_sg->_flag[blockIdx.x] = next_flag;
  }
}
#endif
template
<
typename
P
,
int
ngpus
,
typename
A
>
DINLINE
P
packed_reduce
(
const
P
*
ptrs
[],
int
idx
)
{
A
tmp
=
upcast
(
ptrs
[
0
][
idx
]);
...
...
@@ -220,13 +311,13 @@ __global__ void __launch_bounds__(512, 1)
// note: we don't reorder the address so the accumulation order is the same
// for all ranks, ensuring bitwise identical results
auto
dp
=
*
_dp
;
multi_gpu_barrier
<
ngpus
,
true
>
(
sg
,
self_sg
,
rank
);
start_sync
<
ngpus
>
(
sg
,
self_sg
,
rank
);
// do the actual reduction
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
((
P
*
)
result
)[
idx
]
=
packed_reduce
<
P
,
ngpus
,
A
>
((
const
P
**
)
&
dp
.
ptrs
[
0
],
idx
);
}
multi_gpu_barrier
<
ngpus
,
fals
e
>
(
sg
,
self_sg
,
rank
);
end_sync
<
ngpus
,
tru
e
>
(
sg
,
self_sg
,
rank
);
}
template
<
typename
P
>
...
...
@@ -255,12 +346,13 @@ __global__ void __launch_bounds__(512, 1)
tmps
[
i
]
=
get_tmp_buf
<
P
>
(
sg
.
signals
[
target
]);
}
auto
tmp_out
=
tmps
[
0
];
multi_gpu_barrier
<
ngpus
,
true
>
(
sg
,
self_sg
,
rank
);
start_sync
<
ngpus
>
(
sg
,
self_sg
,
rank
);
// stage 1: reduce scatter
for
(
int
idx
=
start
+
tid
;
idx
<
end
;
idx
+=
stride
)
{
tmp_out
[
idx
-
start
]
=
packed_reduce
<
P
,
ngpus
,
A
>
(
ptrs
,
idx
);
}
multi_gpu_barrier
<
ngpus
,
false
,
true
>
(
sg
,
self_sg
,
rank
);
end_sync
<
ngpus
>
(
sg
,
self_sg
,
rank
);
// stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed
...
...
@@ -290,7 +382,7 @@ class CustomAllreduce {
bool
full_nvlink_
;
RankSignals
sg_
;
// Stores an map from a pointer to its peer point
t
ers from all ranks.
// Stores a map from a pointer to its peer pointers from all ranks.
std
::
unordered_map
<
void
*
,
RankData
*>
buffers_
;
Signal
*
self_sg_
;
...
...
@@ -361,8 +453,7 @@ class CustomAllreduce {
void
*
base_ptr
;
// note: must share the base address of each allocation, or we get wrong
// address
if
(
cuPointerGetAttribute
(
&
base_ptr
,
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR
,
if
(
cuPointerGetAttribute
(
&
base_ptr
,
rangeStartAddrAttr
,
(
CUdeviceptr
)
ptr
)
!=
CUDA_SUCCESS
)
throw
std
::
runtime_error
(
"failed to get pointer attr"
);
CUDACHECK
(
cudaIpcGetMemHandle
(
...
...
@@ -439,7 +530,7 @@ class CustomAllreduce {
*/
template
<
typename
T
>
void
allreduce
(
cudaStream_t
stream
,
T
*
input
,
T
*
output
,
int
size
,
int
threads
=
512
,
int
block_limit
=
36
)
{
int
threads
=
512
,
int
block_limit
=
defaultBlockLimit
)
{
auto
d
=
packed_t
<
T
>::
P
::
size
;
if
(
size
%
d
!=
0
)
throw
std
::
runtime_error
(
...
...
@@ -473,8 +564,6 @@ class CustomAllreduce {
#define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size);
// TODO(hanzhi713): Threshold is different for A100 and H100.
// Add per device threshold.
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
...
...
@@ -497,7 +586,8 @@ class CustomAllreduce {
REDUCE_CASE
(
8
)
default:
throw
std
::
runtime_error
(
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
"custom allreduce only supports num gpus in (2,4,6,8). Actual "
"num "
"gpus = "
+
std
::
to_string
(
world_size_
));
}
...
...
csrc/custom_all_reduce_test.cu
View file @
469e903b
...
...
@@ -20,9 +20,16 @@
#include <vector>
#include "cuda_profiler_api.h"
#include "custom_all_reduce.cuh"
#include "mpi.h"
#include "nccl.h"
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef
__hip_bfloat16
nv_bfloat16
;
#include "rccl/rccl.h"
#include "custom_all_reduce_hip.cuh"
#else
#include "nccl.h"
#include "custom_all_reduce.cuh"
#endif
#define MPICHECK(cmd) \
do { \
...
...
@@ -44,13 +51,16 @@
} while (0)
__global__
void
dummy_kernel
()
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
for
(
int
i
=
0
;
i
<
100
;
i
++
)
__nanosleep
(
1000000
);
// 100ms
#else
#ifdef USE_ROCM
for
(
int
i
=
0
;
i
<
100
;
i
++
)
{
long
long
int
start
=
clock64
();
while
(
clock64
()
-
start
<
150000000
);
// approximately 98.4ms on P40
uint64_t
start
=
wall_clock64
();
uint64_t
cycles_elapsed
;
do
{
cycles_elapsed
=
wall_clock64
()
-
start
;
}
while
(
cycles_elapsed
<
100
);
}
#else
for
(
int
i
=
0
;
i
<
100
;
i
++
)
__nanosleep
(
1000000
);
// 100ms
#endif
}
...
...
@@ -121,8 +131,14 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
* registration, they are allocated and registered together in the test for
* convenience.
*/
#ifdef USE_ROCM
CUDACHECK
(
hipExtMallocWithFlags
(
(
void
**
)
&
buffer
,
2
*
data_size
*
sizeof
(
T
)
+
sizeof
(
vllm
::
Signal
),
hipDeviceMallocUncached
));
#else
CUDACHECK
(
cudaMalloc
(
&
buffer
,
2
*
data_size
*
sizeof
(
T
)
+
sizeof
(
vllm
::
Signal
)));
#endif
CUDACHECK
(
cudaMemset
(
buffer
,
0
,
2
*
data_size
*
sizeof
(
T
)
+
sizeof
(
vllm
::
Signal
)));
CUDACHECK
(
cudaMalloc
(
&
self_data_copy
,
data_size
*
sizeof
(
T
)));
...
...
@@ -135,26 +151,24 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
void
*
rank_data
;
size_t
rank_data_sz
=
16
*
1024
*
1024
;
CUDACHECK
(
cudaMalloc
(
&
rank_data
,
rank_data_sz
));
vllm
::
Signal
*
ipc_ptrs
[
8
];
for
(
int
i
=
0
;
i
<
nRanks
;
i
++
)
{
if
(
i
==
myRank
)
ipc_ptrs
[
i
]
=
buffer
;
else
CUDACHECK
(
cudaIpcOpenMemHandle
((
void
**
)
&
ipc_ptrs
[
i
],
data_handles
[
i
],
cudaIpcMemLazyEnablePeerAccess
));
}
vllm
::
CustomAllreduce
fa
(
ipc_ptrs
,
rank_data
,
rank_data_sz
,
myRank
,
nRanks
);
std
::
vector
<
int64_t
>
offsets
(
nRanks
,
0
);
vllm
::
CustomAllreduce
fa
(
buffer
,
rank_data
,
rank_data_sz
,
data_handles
,
offsets
,
myRank
);
auto
*
self_data
=
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
char
*>
(
buffer
)
+
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
));
// hack buffer registration
{
void
*
data
[
8
];
std
::
vector
<
std
::
string
>
handles
;
handles
.
reserve
(
nRanks
);
for
(
int
i
=
0
;
i
<
nRanks
;
i
++
)
{
data
[
i
]
=
((
char
*
)
ipc_ptrs
[
i
])
+
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
);
char
*
begin
=
(
char
*
)
&
data_handles
[
i
];
char
*
end
=
(
char
*
)
&
data_handles
[
i
+
1
];
handles
.
emplace_back
(
begin
,
end
);
}
fa
.
register_buffer
(
data
);
std
::
vector
<
int64_t
>
offsets
(
nRanks
,
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
));
fa
.
register_buffer
(
handles
,
offsets
,
self_data
);
}
double
*
ground_truth
;
...
...
@@ -266,14 +280,14 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
if
(
diff
>=
4e-2
)
{
printf
(
"Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f
\n
"
,
myRank
,
j
,
nccl_result
[
j
],
my_result
[
j
],
ground_truth
[
j
]);
myRank
,
j
,
nccl_result
[
j
],
my_result
[
j
],
ground_truth
[
j
]);
break
;
}
}
}
if
(
myRank
==
0
)
printf
(
"Test passed: nGPUs:%d, sz (kb): %d, %d, %d
\n
"
,
nRanks
,
data_size
*
sizeof
(
T
)
/
1024
,
threads
,
block_limit
);
data_size
*
sizeof
(
T
)
/
1024
,
threads
,
block_limit
);
// long double nccl_diffs = 0.0;
// long double my_diffs = 0.0;
// for (int j = 0; j < data_size; j++) {
...
...
@@ -306,24 +320,27 @@ int main(int argc, char** argv) {
ncclComm_t
comm
;
if
(
myRank
==
0
)
ncclGetUniqueId
(
&
id
);
MPICHECK
(
MPI_Bcast
(
static_cast
<
void
*>
(
&
id
),
sizeof
(
id
),
MPI_BYTE
,
0
,
MPI_COMM_WORLD
));
MPI_COMM_WORLD
));
NCCLCHECK
(
ncclCommInitRank
(
&
comm
,
nRanks
,
id
,
myRank
));
bool
performance_test
=
true
;
cudaProfilerStart
();
// Uncomment to scan through different block size configs.
// for (int threads : {256, 512, 1024}) {
// for (int threads : {256, 512}) {
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
// run<half>(myRank, nRanks, comm, threads, block_limit, 1024 * 1024,
// performance_test);
// run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
// }
// }
// Scan through different sizes to test performance.
#ifdef USE_ROCM
for
(
int
sz
=
512
;
sz
<=
(
8
<<
20
);
sz
*=
2
)
{
run
<
half
>
(
myRank
,
nRanks
,
comm
,
512
,
16
,
sz
+
8
*
47
,
performance_test
);
}
#else
for
(
int
sz
=
512
;
sz
<=
(
8
<<
20
);
sz
*=
2
)
{
run
<
half
>
(
myRank
,
nRanks
,
comm
,
512
,
36
,
sz
+
8
*
47
,
performance_test
);
}
#endif
cudaProfilerStop
();
MPICHECK
(
MPI_Finalize
());
return
EXIT_SUCCESS
;
}
}
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
View file @
469e903b
...
...
@@ -122,8 +122,8 @@ struct ScaledEpilogue
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
return
ArgumentType
{
a_args
,
evt0_args
};
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
,
{},
{}
};
return
ArgumentType
{
a_args
,
evt0_args
,
{}
};
}
};
...
...
@@ -167,8 +167,8 @@ struct ScaledEpilogueBias
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
return
ArgumentType
{
a_args
,
evt0_args
,
bias_args
};
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
,
{},
{}
};
return
ArgumentType
{
a_args
,
evt0_args
,
bias_args
,
{}
};
}
};
...
...
@@ -230,9 +230,10 @@ struct ScaledEpilogueBiasAzp
auto
azp_adj_args
=
SUPER
::
template
args_from_tensor
<
AzpWithAdj
,
int32_t
>(
azp_adj
);
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{{},
azp_adj_args
};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
b_args
,
evt_azp_args
};
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
};
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{{},
azp_adj_args
,
{}};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
b_args
,
evt_azp_args
,
{}};
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
,
{}};
}
};
...
...
@@ -309,11 +310,12 @@ struct ScaledEpilogueBiasAzpToken
auto
azp_adj_args
=
SUPER
::
template
args_from_tensor
<
AzpAdj
,
int32_t
>(
azp_adj
);
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{
azp_args
,
azp_adj_args
};
typename
EVTComputeAcc
::
Arguments
evt_acc_args
{{},
evt_azp_args
};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
b_args
,
evt_acc_args
};
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
};
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{
azp_args
,
azp_adj_args
,
{}};
typename
EVTComputeAcc
::
Arguments
evt_acc_args
{{},
evt_azp_args
,
{}};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
b_args
,
evt_acc_args
,
{}};
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
,
{}};
}
};
};
// namespace vllm::c2x
\ No newline at end of file
};
// namespace vllm::c2x
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
View file @
469e903b
...
...
@@ -22,7 +22,7 @@ struct identity {
T
operator
()(
T
lhs
)
const
{
return
lhs
;
}
};
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
TileShape
>
struct
TrivialEpilogue
{
private:
using
Accum
=
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
;
...
...
@@ -44,32 +44,30 @@ struct TrivialEpilogue {
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
TileShape
>
struct
ScaledEpilogueBase
{
protected:
using
Accum
=
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
;
template
<
typename
T
>
using
ColOrScalarLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90ColOrScalarBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
0
/*Stages*/
,
TileShape
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
template
<
typename
T
>
using
RowOrScalarLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90RowOrScalarBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
0
/*Stages*/
,
TileShape
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
// Don't want to support nullptr by default
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
using
ColLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90ColBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
0
/*Stages*/
,
TileShape
,
T
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
// Don't want to support nullptr by default
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
using
RowLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
0
/*Stages*/
,
TileShape
,
T
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
...
...
@@ -116,11 +114,11 @@ struct ScaledEpilogueBase {
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
TileShape
>
struct
ScaledEpilogue
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
...
...
@@ -146,8 +144,8 @@ struct ScaledEpilogue
auto
a_args
=
SUPER
::
template
args_from_tensor
<
ScaleA
,
float
>(
a_scales
);
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
return
ArgumentType
{
a_args
,
evt0_args
};
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
,
{},
{}
};
return
ArgumentType
{
a_args
,
evt0_args
,
{}
};
}
};
...
...
@@ -160,11 +158,11 @@ struct ScaledEpilogue
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
TileShape
>
struct
ScaledEpilogueBias
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
...
...
@@ -193,8 +191,8 @@ struct ScaledEpilogueBias
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
return
ArgumentType
{
a_args
,
evt0_args
,
bias_args
};
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
,
{},
{}
};
return
ArgumentType
{
a_args
,
evt0_args
,
bias_args
,
{}
};
}
};
...
...
@@ -203,11 +201,11 @@ struct ScaledEpilogueBias
* bias is a column vector instead of a row vector. Useful e.g. if we are
* computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
TileShape
>
struct
ScaledEpilogueColumnBias
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
...
...
@@ -236,8 +234,8 @@ struct ScaledEpilogueColumnBias
auto
b_args
=
SUPER
::
template
args_from_tensor
<
ScaleB
,
float
>(
b_scales
);
auto
bias_args
=
SUPER
::
template
args_from_tensor
<
Bias
,
ElementD
>(
bias
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
};
return
ArgumentType
{
a_args
,
evt0_args
,
bias_args
};
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
,
{},
{}
};
return
ArgumentType
{
a_args
,
evt0_args
,
bias_args
,
{}
};
}
};
...
...
@@ -249,11 +247,11 @@ struct ScaledEpilogueColumnBias
*
* This epilogue also supports bias, which remains per-channel.
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
TileShape
>
struct
ScaledEpilogueBiasAzp
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
...
...
@@ -297,9 +295,10 @@ struct ScaledEpilogueBiasAzp
auto
azp_adj_args
=
SUPER
::
template
args_from_tensor
<
AzpWithAdj
,
int32_t
>(
azp_adj
);
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{{},
azp_adj_args
};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
b_args
,
evt_azp_args
};
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
};
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{{},
azp_adj_args
,
{}};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
b_args
,
evt_azp_args
,
{}};
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
,
{}};
}
};
...
...
@@ -313,11 +312,11 @@ struct ScaledEpilogueBiasAzp
*
* This epilogue also supports bias, which remains per-channel.
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
TileShape
>
struct
ScaledEpilogueBiasAzpToken
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
{
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
...
...
@@ -374,10 +373,11 @@ struct ScaledEpilogueBiasAzpToken
auto
azp_adj_args
=
SUPER
::
template
args_from_tensor
<
AzpAdj
,
int32_t
>(
azp_adj
);
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{
azp_args
,
azp_adj_args
};
typename
EVTComputeAcc
::
Arguments
evt_acc_args
{{},
evt_azp_args
};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
b_args
,
evt_acc_args
};
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
};
typename
EVTComputeAzp
::
Arguments
evt_azp_args
{
azp_args
,
azp_adj_args
,
{}};
typename
EVTComputeAcc
::
Arguments
evt_acc_args
{{},
evt_azp_args
,
{}};
typename
EVTComputeScaleB
::
Arguments
evt_scale_b_args
{
b_args
,
evt_acc_args
,
{}};
return
ArgumentType
{
a_args
,
evt_scale_b_args
,
bias_args
,
{}};
}
};
...
...
csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
View file @
469e903b
...
...
@@ -402,7 +402,7 @@ struct CollectiveMma<
// TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128
TiledCopy
scale_copy_a
=
make_tiled_copy
(
SmemBlockScalingCopyAtomA
{},
Layout
<
Shape
<
_32
,
_1
>>
{},
Layout
<
Shape
<
_4
,
_1
>>
{});
// (1,1,1)
Layout
<
Shape
<
_32
>>
{},
Layout
<
Shape
<
_1
>>
{});
// (1,1,1)
TiledCopy
scale_copy_b
=
make_tiled_copy
(
SmemBlockScalingCopyAtomB
{},
Layout
<
Shape
<
_1
>>
{},
Layout
<
Shape
<
_1
>>
{});
// (1,1,1)
ThrCopy
thr_scale_copy_a
=
scale_copy_a
.
get_slice
(
threadIdx
.
x
);
...
...
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
import
enum
from
typing
import
Dict
,
Union
from
typing
import
Union
from
cutlass_library
import
*
...
...
@@ -21,7 +21,7 @@ class MixedInputKernelScheduleType(enum.Enum):
TmaWarpSpecializedCooperative
=
enum_auto
()
VLLMDataTypeNames
:
D
ict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
VLLMDataTypeNames
:
d
ict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
**
DataTypeNames
,
# type: ignore
**
{
VLLMDataType
.
u4b8
:
"u4b8"
,
...
...
@@ -29,7 +29,7 @@ VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = {
}
}
VLLMDataTypeTag
:
D
ict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
VLLMDataTypeTag
:
d
ict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
**
DataTypeTag
,
# type: ignore
**
{
VLLMDataType
.
u4b8
:
"cutlass::vllm_uint4b8_t"
,
...
...
@@ -37,7 +37,7 @@ VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
}
}
VLLMDataTypeSize
:
D
ict
[
Union
[
VLLMDataType
,
DataType
],
int
]
=
{
VLLMDataTypeSize
:
d
ict
[
Union
[
VLLMDataType
,
DataType
],
int
]
=
{
**
DataTypeSize
,
# type: ignore
**
{
VLLMDataType
.
u4b8
:
4
,
...
...
@@ -45,7 +45,7 @@ VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = {
}
}
VLLMDataTypeVLLMScalarTypeTag
:
D
ict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
VLLMDataTypeVLLMScalarTypeTag
:
d
ict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
VLLMDataType
.
u4b8
:
"vllm::kU4B8"
,
VLLMDataType
.
u8b128
:
"vllm::kU8B128"
,
DataType
.
u4
:
"vllm::kU4"
,
...
...
@@ -56,7 +56,7 @@ VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
DataType
.
bf16
:
"vllm::kBfloat16"
,
}
VLLMDataTypeTorchDataTypeTag
:
D
ict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
VLLMDataTypeTorchDataTypeTag
:
d
ict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
DataType
.
u8
:
"at::ScalarType::Byte"
,
DataType
.
s8
:
"at::ScalarType::Char"
,
DataType
.
e4m3
:
"at::ScalarType::Float8_e4m3fn"
,
...
...
@@ -66,7 +66,7 @@ VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
DataType
.
f32
:
"at::ScalarType::Float"
,
}
VLLMKernelScheduleTag
:
D
ict
[
Union
[
VLLMKernelScheduleTag
:
d
ict
[
Union
[
MixedInputKernelScheduleType
,
KernelScheduleType
],
str
]
=
{
**
KernelScheduleTag
,
# type: ignore
**
{
...
...
csrc/dispatch_utils.h
View file @
469e903b
...
...
@@ -6,6 +6,11 @@
#include <torch/all.h>
// Need a special dispatch case macro since we will nest the FP8 dispatch.
// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'.
#define AT_DISPATCH_FP8_CASE(enum_type, ...) \
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
...
...
@@ -14,17 +19,32 @@
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
// TODO(luka/varun): use FP8_TYPE macro after refactoring
#ifndef USE_ROCM
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#else
// ROCm devices might use either fn or fnuz, so set up dispatch table for both.
// A host-based check at runtime will create a preferred FP8 type for ROCm
// such that the correct kernel is dispatched.
#ifdef USE_ROCM
#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#else
#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#endif
// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
// See AT_DISPATCH_FP8_CASE above.
#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
...
...
csrc/layernorm_quant_kernels.cu
View file @
469e903b
...
...
@@ -21,9 +21,9 @@
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
template
<
typename
scalar_t
>
template
<
typename
scalar_t
,
typename
fp8_type
>
__global__
void
rms_norm_static_fp8_quant_kernel
(
FP8_TYPE
*
__restrict__
out
,
// [..., hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
*
__restrict__
scale
,
// [1]
...
...
@@ -52,7 +52,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
scaled_fp8_conversion
<
true
>
(
out_norm
,
scale_inv
);
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
}
}
...
...
@@ -60,10 +60,10 @@ __global__ void rms_norm_static_fp8_quant_kernel(
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck. */
template
<
typename
scalar_t
,
int
width
>
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_static_fp8_quant_kernel
(
FP8_TYPE
*
__restrict__
out
,
// [..., hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
...
...
@@ -114,7 +114,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
out
[
id
*
width
+
i
]
=
scaled_fp8_conversion
<
true
>
(
float
(
temp
.
data
[
i
]),
scale_inv
);
scaled_fp8_conversion
<
true
,
fp8_type
>
(
float
(
temp
.
data
[
i
]),
scale_inv
);
}
}
}
...
...
@@ -122,10 +122,10 @@ fused_add_rms_norm_static_fp8_quant_kernel(
/* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template
<
typename
scalar_t
,
int
width
>
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_static_fp8_quant_kernel
(
FP8_TYPE
*
__restrict__
out
,
// [..., hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
...
...
@@ -158,7 +158,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
float
x
=
(
float
)
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
scaled_fp8_conversion
<
true
>
(
out_norm
,
scale_inv
);
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
}
}
...
...
@@ -176,25 +176,33 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
vllm
::
rms_norm_static_fp8_quant_kernel
<
scalar_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
FP8_TYPE
>
(),
input
.
data_ptr
<
scalar_t
>
(),
weight
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel_scalar_type"
,
[
&
]
{
VLLM_DISPATCH_FP8_TYPES
(
out
.
scalar_type
(),
"rms_norm_kernel_fp8_type"
,
[
&
]
{
vllm
::
rms_norm_static_fp8_quant_kernel
<
scalar_t
,
fp8_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
fp8_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
weight
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
});
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>( \
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
scale.data_ptr<float>(), epsilon, num_tokens, hidden_size); \
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \
VLLM_DISPATCH_FP8_TYPES( \
out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, \
width, fp8_t> \
<<<grid, block, 0, stream>>>( \
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), \
epsilon, num_tokens, hidden_size); \
}); \
});
void
fused_add_rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size],
torch
::
Tensor
&
input
,
// [..., hidden_size]
...
...
csrc/moe/moe_ops.h
View file @
469e903b
...
...
@@ -24,3 +24,14 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
);
#ifndef USE_ROCM
// Declaration of moe_wna16_gemm; the definition is not part of this header.
// It is only declared when USE_ROCM is not defined (see the enclosing
// #ifndef).
//
// Takes the tensors `input`, `output`, `b_qweight`, `b_scales`,
// `sorted_token_ids`, `expert_ids` and `num_tokens_post_pad` by value, plus
// two optional tensors (`b_qzeros`, `topk_weights`) that a caller may leave
// empty, and five int64_t scalars. Returns a torch::Tensor.
//
// NOTE(review): presumably `BLOCK_SIZE_M/N/K` are tile sizes, `top_k` the
// number of experts selected per token and `bit` the quantized weight width
// -- confirm against the kernel implementation.
torch
::
Tensor
moe_wna16_gemm
(
torch
::
Tensor
input
,
torch
::
Tensor
output
,
torch
::
Tensor
b_qweight
,
torch
::
Tensor
b_scales
,
std
::
optional
<
torch
::
Tensor
>
b_qzeros
,
std
::
optional
<
torch
::
Tensor
>
topk_weights
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
expert_ids
,
torch
::
Tensor
num_tokens_post_pad
,
int64_t
top_k
,
int64_t
BLOCK_SIZE_M
,
int64_t
BLOCK_SIZE_N
,
int64_t
BLOCK_SIZE_K
,
int64_t
bit
);
#endif
\ No newline at end of file
Prev
1
2
3
4
5
6
7
8
9
…
27
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment