"examples/basics/multinode/trtllm/srun_aggregated.sh" did not exist on "5bf23d54f3e46a15ff5000773a32d8829befa919"
scaled_mm_c3x.cu 3.49 KB
Newer Older
1
2
3
4
#include <cudaTypedefs.h>

#if defined CUDA_VERSION && CUDA_VERSION >= 12000

5
6
  #include "scaled_mm_c3x_sm90_fp8_dispatch.cuh"
  #include "scaled_mm_c3x_sm90_int8_dispatch.cuh"
7

8
  #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
9
using namespace vllm;
10
11

/*
   This file defines quantized GEMM operations using the CUTLASS 3.x API, for
   NVIDIA GPUs with sm90a (Hopper) or later.
*/

16
17
18
19
20
// Selects and invokes the sm90 CUTLASS GEMM dispatcher that matches the
// runtime dtypes of the operands.
//
// Template parameters:
//   Epilogue      - epilogue template forwarded to the dispatcher; it is
//                   instantiated there with three type arguments.
//   EpilogueArgs  - types of the extra arguments consumed by the epilogue.
//
// Parameters:
//   out            - output tensor; must be bfloat16 or float16.
//   a, b           - input operands; both int8 or both float8_e4m3fn.
//   epilogue_args  - perfectly forwarded to the selected dispatcher.
//
// Fails via TORCH_CHECK when the dtype combination is not supported.
template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  if (a.dtype() == torch::kInt8) {
    // int8 path: both operands must share the int8 dtype.
    TORCH_CHECK(b.dtype() == torch::kInt8,
                "b must be int8 when a is int8, got ", b.dtype());

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                             Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16,
                  "out must be bfloat16 or float16, got ", out.dtype());
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else {
    // Any non-int8 input is required to be fp8 (e4m3) on both operands.
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
                "a must be int8 or float8_e4m3fn, got ", a.dtype());
    TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn,
                "b must be float8_e4m3fn when a is float8_e4m3fn, got ",
                b.dtype());

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::bfloat16_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16,
                  "out must be bfloat16 or float16, got ", out.dtype());
      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  }
}
49

50
51
52
53
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
54
                            std::optional<torch::Tensor> const& bias) {
55
56
57
58
59
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (bias) {
    TORCH_CHECK(bias->dtype() == c.dtype(),
                "currently bias dtype must match output dtype ", c.dtype());
60
    return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
61
62
        c, a, b, a_scales, b_scales, *bias);
  } else {
63
64
    return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogue>(
        c, a, b, a_scales, b_scales);
65
66
67
  }
}

68
69
70
71
72
// Scaled GEMM entry point for sm90 with an `azp` adjustment term: validates
// the scale dtypes and forwards to cutlass_scaled_mm_sm90_epilogue.
// NOTE(review): "azp" presumably stands for asymmetric zero point - confirm
// against the epilogue definitions.
//
// Parameters:
//   out       - output tensor (its dtype is validated by the epilogue
//               dispatcher).
//   a, b      - quantized input operands.
//   a_scales  - scales for `a`; must be float32.
//   b_scales  - scales for `b`; must be float32.
//   azp_adj   - adjustment tensor passed to both epilogue variants.
//   azp       - optional; when present, ScaledEpilogueBiasAzpToken is used
//               and receives `*azp`; otherwise ScaledEpilogueBiasAzp is used.
//   bias      - optional bias, forwarded as-is (still wrapped in
//               std::optional) to the epilogue.
//
// Fails via TORCH_CHECK when a scale tensor is not float32.
void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32,
              "a_scales must be float32, got ", a_scales.dtype());
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32,
              "b_scales must be float32, got ", b_scales.dtype());

  if (azp) {
    return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzpToken>(
        out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
  } else {
    return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
}

87
#endif