"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "38b14e8b06c4e24e13be1c3e994fee307e4d3c6e"
Commit 6dc7aa42 authored by zhuwenwen's avatar zhuwenwen
Browse files

remove fatrelu_and_mul

parent 02c3e313
......@@ -107,41 +107,41 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
}
template <typename T>
__device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) {
const float f = (float)x;
return (T)(f > threshold ? f : 0.0f);
}
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
__global__ void act_and_mul_kernel_with_param(
scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
const float param) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = ACT_FN(x, param) * y;
}
}
// template <typename T>
// __device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) {
// const float f = (float)x;
// return (T)(f > threshold ? f : 0.0f);
// }
// template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
// __global__ void act_and_mul_kernel_with_param(
// scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
// const float param) {
// const int64_t token_idx = blockIdx.x;
// for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
// const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
// const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
// out[token_idx * d + idx] = ACT_FN(x, param) * y;
// }
// }
} // namespace vllm
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d, \
PARAM); \
});
// #define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
// int d = input.size(-1) / 2; \
// int64_t num_tokens = input.numel() / input.size(-1); \
// dim3 grid(num_tokens); \
// dim3 block(std::min(d, 1024)); \
// const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
// const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
// VLLM_DISPATCH_FLOATING_TYPES( \
// input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \
// vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \
// <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
// input.data_ptr<scalar_t>(), d, \
// PARAM); \
// });
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
......@@ -200,8 +200,8 @@ void gelu_tanh_and_mul_opt(torch::Tensor& out, // [..., d]
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
}
void fatrelu_and_mul(torch::Tensor& out, // [..., d],
torch::Tensor& input, // [..., 2 * d]
double threshold) {
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
}
\ No newline at end of file
// void fatrelu_and_mul_opt(torch::Tensor& out, // [..., d],
// torch::Tensor& input, // [..., 2 * d]
// double threshold) {
// LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
// }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment