Commit 6dc7aa42 authored by zhuwenwen's avatar zhuwenwen
Browse files

remove fatrelu_and_mul

parent 02c3e313
...@@ -107,41 +107,41 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { ...@@ -107,41 +107,41 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
return (T)(0.5f * f * (1.0f + ::tanhf(inner))); return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
} }
template <typename T> // template <typename T>
__device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) { // __device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) {
const float f = (float)x; // const float f = (float)x;
return (T)(f > threshold ? f : 0.0f); // return (T)(f > threshold ? f : 0.0f);
} // }
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)> // template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
__global__ void act_and_mul_kernel_with_param( // __global__ void act_and_mul_kernel_with_param(
scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d, // scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
const float param) { // const float param) {
const int64_t token_idx = blockIdx.x; // const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { // for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); // const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); // const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = ACT_FN(x, param) * y; // out[token_idx * d + idx] = ACT_FN(x, param) * y;
} // }
} // }
} // namespace vllm } // namespace vllm
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \ // #define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
int d = input.size(-1) / 2; \ // int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \ // int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \ // dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \ // dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ // const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ // const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \ // VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \ // input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \ // vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \ // <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d, \ // input.data_ptr<scalar_t>(), d, \
PARAM); \ // PARAM); \
}); // });
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ #define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
...@@ -200,8 +200,8 @@ void gelu_tanh_and_mul_opt(torch::Tensor& out, // [..., d] ...@@ -200,8 +200,8 @@ void gelu_tanh_and_mul_opt(torch::Tensor& out, // [..., d]
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
} }
void fatrelu_and_mul(torch::Tensor& out, // [..., d], // void fatrelu_and_mul_opt(torch::Tensor& out, // [..., d],
torch::Tensor& input, // [..., 2 * d] // torch::Tensor& input, // [..., 2 * d]
double threshold) { // double threshold) {
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold); // LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
} // }
\ No newline at end of file \ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment