/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/cuda/CUDAContext.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include "torch/csrc/cuda/Stream.h"
#include <torch/custom_class.h>
#include <torch/script.h>

#include "src/fastertransformer/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"
#include "src/fastertransformer/th_op/th_utils.h"
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"

#include "cutlass/numeric_types.h"

using torch::Tensor;

namespace torch_ext {
namespace ft = fastertransformer;

template<typename T, typename WeightType>
Tensor fused_gemm_dq_helper(
    Tensor input_activations, Tensor weight, Tensor scales, const int64_t timing_iterations, float& avg_time)
{
    const at::ScalarType _st = input_activations.scalar_type();
    const int            m   = input_activations.size(0);
    const int            n   = scales.size(0);
    const int            k   = input_activations.size(1);

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    const T*          input_act_ptr = get_ptr<const T>(input_activations);
    const WeightType* weight_ptr    = get_ptr<const WeightType>(weight);
    const T*          scales_ptr    = get_ptr<const T>(scales);

    fastertransformer::CutlassFpAIntBGemmRunner<T, WeightType> fused_gemm_dq_runner;
    const int ws_bytes = fused_gemm_dq_runner.getWorkspaceSize(m, n, k);

    auto output_tensor = torch::empty({m, n}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
    auto ws_tensor     = torch::empty({ws_bytes}, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));

    T*    output_tensor_ptr = get_ptr<T>(output_tensor);
    char* ws_ptr            = get_ptr<char>(ws_tensor);

    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start, stream);
    for (int64_t iter = 0; iter < timing_iterations; ++iter) {
        fused_gemm_dq_runner.gemm(
            input_act_ptr, weight_ptr, scales_ptr, output_tensor_ptr, m, n, k, ws_ptr, ws_bytes, stream);
    }
    cudaEventRecord(stop, stream);
    cudaEventSynchronize(stop);
    float total_time_ms = 0;
    cudaEventElapsedTime(&total_time_ms, start, stop);
    avg_time = total_time_ms / float(timing_iterations);

    return output_tensor;
}

Tensor
_fused_gemm_dq(Tensor input_activations, Tensor weight, Tensor scales, int64_t timing_iterations, float& avg_time)
{
    const at::ScalarType _st = input_activations.scalar_type();
    CHECK_INPUT(scales, _st);

    TORCH_CHECK(input_activations.dim() == 2, "Invalid rank for activations");
    TORCH_CHECK(weight.dim() == 2, "Invalid rank for weight");
    TORCH_CHECK(scales.dim() == 1, "Invalid rank for scales");

    const int m = input_activations.size(0);
    const int n = scales.size(0);
    const int k = input_activations.size(1);

    TORCH_CHECK(input_activations.size(1) == weight.size(0), "dim 1 of act and dim 0 of weight must be equal");

    // We signal int4 by having the last weight dim be half the size of the scales.
    // This is because two int4 elements are packed into a single byte.
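    // Illustration (shapes are only an example): for k = 4096 and n = 8192, an int8 weight is a
    // [4096, 8192] tensor and scales is [8192], so the last dims match. An int4 weight packs two
    // elements per byte and is stored as [4096, 4096] bytes while scales stays [8192], making the
    // last weight dim half the last scales dim.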
    torch::ScalarType quant_type = weight.scalar_type();
    if (weight.size(-1) == scales.size(-1) / 2) {
        quant_type = at::ScalarType::QUInt4x2;
    }
    else {
        TORCH_CHECK(weight.size(-1) == scales.size(-1),
                    "Last dim of weight and scales must be equal for int8 "
                    "or last dim of scale must be 2x last dim of weight for int4.");
    }

    Tensor output_tensor;
    switch (_st) {
        case at::ScalarType::Half: {
            if (quant_type == torch::kInt8) {
                output_tensor = fused_gemm_dq_helper<half, uint8_t>(
                    input_activations, weight, scales, timing_iterations, avg_time);
            }
            else if (quant_type == at::ScalarType::QUInt4x2) {
                output_tensor = fused_gemm_dq_helper<half, cutlass::uint4b_t>(
                    input_activations, weight, scales, timing_iterations, avg_time);
            }
            else {
                std::string err_msg = "Unsupported weight type " + std::string(at::toString(quant_type));
                throw std::runtime_error(err_msg);
            }
            break;
        }
#ifdef ENABLE_BF16
        case at::ScalarType::BFloat16: {
            if (quant_type == torch::kInt8) {
                output_tensor = fused_gemm_dq_helper<__nv_bfloat16, uint8_t>(
                    input_activations, weight, scales, timing_iterations, avg_time);
            }
            else if (quant_type == at::ScalarType::QUInt4x2) {
                output_tensor = fused_gemm_dq_helper<__nv_bfloat16, cutlass::uint4b_t>(
                    input_activations, weight, scales, timing_iterations, avg_time);
            }
            else {
                std::string err_msg = "Unsupported weight type " + std::string(at::toString(quant_type));
                throw std::runtime_error(err_msg);
            }
            break;
        }
#endif
        default:
            throw std::runtime_error("Unsupported tensor type. Got " + std::string(at::toString(_st)));
    }

    return output_tensor;
}

Tensor fused_gemm_dq(Tensor input_activations, Tensor weight, Tensor scales)
{
    float dummy = 0.f;
    return _fused_gemm_dq(input_activations, weight, scales, 1, dummy);
}

Tensor bench_cublas(Tensor input_activations, Tensor weight_dequantized, const int64_t timing_iterations, float& avg_time)
{
    using namespace fastertransformer;
    const int m = input_activations.size(0);
    const int n = weight_dequantized.size(1);
    const int k = input_activations.size(1);

    const void* input_act_ptr = get_ptr<const void>(input_activations);
    const void* weight_ptr    = get_ptr<const void>(weight_dequantized);

    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();

    const at::ScalarType _st = input_activations.scalar_type();
    TORCH_CHECK(input_activations.size(1) == weight_dequantized.size(0),
                "CUBLAS_BENCH: dim 1 of act and dim 0 of weight must be equal");
    CHECK_INPUT(input_activations, _st);
    CHECK_INPUT(weight_dequantized, _st);

    auto output_tensor = torch::empty({m, n}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
    void* output_tensor_ptr = get_ptr<void>(output_tensor);

    TORCH_CHECK(_st == at::ScalarType::Half || _st == at::ScalarType::BFloat16,
                "Input type must be half or bfloat16");
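    // cuBLAS expects column-major matrices. The benchmark loop below therefore computes
    // output^T = weight^T * act^T by passing the row-major weight as the first operand with
    // problem size (n, m, k) and leading dimensions n, k, n, which writes the row-major
    // m x n output without any explicit transposes.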
    cudaDataType_t cublasType = _st == at::ScalarType::Half ? CUDA_R_16F : CUDA_R_16BF;

    float alpha = 1.0f;
    float beta  = 0.0f;

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    cublasSetStream(handle, stream);

    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
    cudaEventRecord(start, stream);
    for (int64_t iter = 0; iter < timing_iterations; ++iter) {
        status = cublasGemmEx(handle,
                              CUBLAS_OP_N,
                              CUBLAS_OP_N,
                              n,
                              m,
                              k,
                              &alpha,
                              weight_ptr,
                              cublasType,
                              n,
                              input_act_ptr,
                              cublasType,
                              k,
                              &beta,
                              output_tensor_ptr,
                              cublasType,
                              n,
                              CUBLAS_COMPUTE_32F,
                              CUBLAS_GEMM_DEFAULT);
    }
    cudaEventRecord(stop, stream);
    cudaEventSynchronize(stop);
    float total_time_ms = 0;
    cudaEventElapsedTime(&total_time_ms, start, stop);
    avg_time = total_time_ms / float(timing_iterations);
    check_cuda_error(status);

    return output_tensor;
}

std::vector<std::vector<Tensor>> benchmark_against_cublas_fp(Tensor        input_activations,
                                                             Tensor        weight_quantized,
                                                             Tensor        scales,
                                                             Tensor        weight_dequantized,
                                                             const int64_t timing_iterations)
{
    float cublas_time = 0.f;
    float ft_time     = 0.f;

    Tensor cublas_result = bench_cublas(input_activations, weight_dequantized, timing_iterations, cublas_time);
    Tensor ft_result     = _fused_gemm_dq(input_activations, weight_quantized, scales, timing_iterations, ft_time);

    auto timing_tensor =
        torch::empty({2}, torch::dtype(at::ScalarType::Float).device(torch::kCPU).requires_grad(false));
    timing_tensor[0] = cublas_time;
    timing_tensor[1] = ft_time;

    // const int m = input_activations.size(0);
    // const int n = weight_dequantized.size(1);
    // const int k = input_activations.size(1);
    // std::cout << "m, n, k" << m << ", " << n << ", " << k << std::endl;
    // std::cout << "cuBLAS time (ms) " << cublas_time << std::endl;
    // std::cout << "FT time (ms) " << ft_time << std::endl;

    return {{timing_tensor}, {cublas_result, ft_result}};
}

template<typename T, typename WeightType>
Tensor fused_gemm_dq_bias_act_helper(
    Tensor input_activations, Tensor weight, Tensor scales, Tensor bias, ft::ActivationType activation_type)
{
    const at::ScalarType _st = input_activations.scalar_type();
    const int            m   = input_activations.size(0);
    const int            n   = scales.size(0);
    const int            k   = input_activations.size(1);

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    const T*          input_act_ptr = get_ptr<const T>(input_activations);
    const WeightType* weight_ptr    = get_ptr<const WeightType>(weight);
    const T*          scales_ptr    = get_ptr<const T>(scales);
    const T*          bias_ptr      = get_ptr<const T>(bias);

    fastertransformer::CutlassFpAIntBGemmRunner<T, WeightType> fused_gemm_dq_runner;
    const int ws_bytes = fused_gemm_dq_runner.getWorkspaceSize(m, n, k);

    auto output_tensor = torch::empty({m, n}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
    auto ws_tensor     = torch::empty({ws_bytes}, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));

    T*    output_tensor_ptr = get_ptr<T>(output_tensor);
    char* ws_ptr            = get_ptr<char>(ws_tensor);

    fused_gemm_dq_runner.gemm_bias_act(input_act_ptr,
                                       weight_ptr,
                                       scales_ptr,
                                       bias_ptr,
                                       output_tensor_ptr,
                                       m,
                                       n,
                                       k,
                                       activation_type,
                                       ws_ptr,
                                       ws_bytes,
                                       stream);

    return output_tensor;
}
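// Dispatcher for the fused GEMM + dequantize + bias + activation path. The bias add and the
// activation are applied in the GEMM epilogue by the CUTLASS runner rather than in a separate
// elementwise kernel; gated activations are rejected below because the fused epilogue does not
// support them.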
Tensor fused_gemm_dq_bias_act(
    Tensor input_activations, Tensor weight, Tensor scales, Tensor bias, std::string activation_type_str)
{
    const at::ScalarType _st = input_activations.scalar_type();
    CHECK_INPUT(scales, _st);
    CHECK_INPUT(bias, _st);

    TORCH_CHECK(input_activations.dim() == 2, "Invalid rank for activations");
    TORCH_CHECK(weight.dim() == 2, "Invalid rank for weight");
    TORCH_CHECK(scales.dim() == 1, "Invalid rank for scales");
    TORCH_CHECK(bias.dim() == 1, "Invalid rank for bias");

    const int m = input_activations.size(0);
    const int n = scales.size(0);
    const int k = input_activations.size(1);

    TORCH_CHECK(bias.size(0) == n, "Must have 1 bias value for each output column");
    TORCH_CHECK(input_activations.size(1) == weight.size(0), "dim 1 of act and dim 0 of weight must be equal");

    // We signal int4 by having the last weight dim be half the size of the scales.
    // This is because two int4 elements are packed into a single byte.
    torch::ScalarType quant_type = weight.scalar_type();
    if (weight.size(-1) == scales.size(-1) / 2) {
        quant_type = at::ScalarType::QUInt4x2;
    }
    else {
        TORCH_CHECK(weight.size(-1) == scales.size(-1),
                    "Last dim of weight and scales must be equal for int8 "
                    "or last dim of scale must be 2x last dim of weight for int4.");
    }

    ft::ActivationType activation_type = ft::ActivationType::InvalidType;
    if (activation_type_str == "identity") {
        activation_type = ft::ActivationType::Identity;
    }
    else {
        activation_type = ft::getActivationType(activation_type_str);
    }

    TORCH_CHECK(!isGatedActivation(activation_type), "Fused gated activations not supported.");

    Tensor output_tensor;
    switch (_st) {
        case at::ScalarType::Half: {
            if (quant_type == torch::kInt8) {
                output_tensor = fused_gemm_dq_bias_act_helper<half, uint8_t>(
                    input_activations, weight, scales, bias, activation_type);
            }
            else if (quant_type == at::ScalarType::QUInt4x2) {
                output_tensor = fused_gemm_dq_bias_act_helper<half, cutlass::uint4b_t>(
                    input_activations, weight, scales, bias, activation_type);
            }
            else {
                std::string err_msg = "Unsupported weight type " + std::string(at::toString(quant_type));
                throw std::runtime_error(err_msg);
            }
            break;
        }
#ifdef ENABLE_BF16
        case at::ScalarType::BFloat16: {
            if (quant_type == torch::kInt8) {
                output_tensor = fused_gemm_dq_bias_act_helper<__nv_bfloat16, uint8_t>(
                    input_activations, weight, scales, bias, activation_type);
            }
            else if (quant_type == at::ScalarType::QUInt4x2) {
                output_tensor = fused_gemm_dq_bias_act_helper<__nv_bfloat16, cutlass::uint4b_t>(
                    input_activations, weight, scales, bias, activation_type);
            }
            else {
                std::string err_msg = "Unsupported weight type " + std::string(at::toString(quant_type));
                throw std::runtime_error(err_msg);
            }
            break;
        }
#endif
        default:
            throw std::runtime_error("Unsupported tensor type. Got " + std::string(at::toString(_st)));
    }

    return output_tensor;
}

TORCH_LIBRARY(gemm_dq_unit_ops, m)
{
    m.def("fused_gemm_dq", fused_gemm_dq);
    m.def("benchmark_against_cublas_fp", benchmark_against_cublas_fp);
    m.def("fused_gemm_dq_bias_act", fused_gemm_dq_bias_act);
}

}  // namespace torch_ext
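
// A minimal usage sketch from Python (illustrative only: the shared-library name and tensor shapes
// are assumptions, and the int8/int4 weight must already be quantized and packed into the layout
// expected by the CUTLASS fpA_intB kernels):
//
//   import torch
//   torch.ops.load_library("libth_gemm_dq_unit_ops.so")  # hypothetical library name
//   act     = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
//   # weight_q: [4096, n] preprocessed int8 (or packed int4) tensor, scales/bias: [n] fp16 tensors
//   out      = torch.ops.gemm_dq_unit_ops.fused_gemm_dq(act, weight_q, scales)
//   out_bias = torch.ops.gemm_dq_unit_ops.fused_gemm_dq_bias_act(act, weight_q, scales, bias, "identity")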