/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include "transformer_engine/gemm.h"

#include <memory>

#include "common/util/cuda_runtime.h"
#include "common/util/system.h"
#include "extensions.h"
#include "xla/ffi/api/c_api.h"

namespace transformer_engine {
namespace jax {

// Block length along the K dimension used for MXFP8 1D scaling in this file:
// K must be a multiple of this value, and the scale-inverse tensors are given
// K / MXFP8_BLOCK_SIZE columns.
constexpr static size_t MXFP8_BLOCK_SIZE = 32;

// Launches `num_gemms` independent GEMMs through nvte_multi_stream_cublas_gemm.
//
// Inputs are flattened buffers: the operands of GEMM i+1 start right after the
// operands of GEMM i, so each base pointer is advanced by the byte size of the
// slice it just consumed.
//
//   lhs_ptr / rhs_ptr           : packed LHS ([m, k]) and RHS ([n, k]) matrices.
//   lhs_sinv_ptr / rhs_sinv_ptr : packed scale-inverse data; one element per GEMM
//                                 for DELAYED_TENSOR_SCALING, [rows, k / 32] per
//                                 GEMM for MXFP8_1D_SCALING.
//   bias_ptr                    : packed biases (n elements per GEMM) or nullptr.
//   out_ptr                     : packed outputs (m * n elements per GEMM).
//   workspace_ptr               : `num_streams` consecutive slices of
//                                 `workspace_size` bytes each.
//   dim_list_ptr                : device buffer with 3 * num_gemms int32 values,
//                                 laid out as (m, n, k) per GEMM.
//
// Throws via NVTE_CHECK / NVTE_ERROR on mismatched dtype sizes, a failed
// dimension read-back, an MXFP8 K dimension that is not a multiple of
// MXFP8_BLOCK_SIZE, or an unsupported scaling mode. Otherwise returns the result
// of ffi_with_cuda_error_check().
//
// Note: we only support TN-GEMM for now (TN in cuBLASLt == NT in JAX)
Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lhs_sinv_ptr,
                           const DType &lhs_sinv_dtype, uint8_t *rhs_ptr, const DType &rhs_dtype,
                           uint8_t *rhs_sinv_ptr, const DType &rhs_sinv_dtype, uint8_t *bias_ptr,
                           const DType &bias_dtype, uint8_t *out_ptr, const DType &out_dtype,
                           uint8_t *workspace_ptr, const size_t workspace_size, size_t num_gemms,
                           int32_t *dim_list_ptr, const JAXX_Scaling_Mode scaling_mode,
                           cudaStream_t stream) {
  size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
  size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
  size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype);
  size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype);
  size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype);
  size_t out_dtype_bytes = te_dtype_bytes(out_dtype);
  NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)");
  NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes,
             "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)");

  // The per-GEMM dimensions live on the device; bring them to the host so the
  // tensor shapes and pointer offsets below can be computed.
  size_t dim_list_bytes = sizeof(int32_t) * 3 * num_gemms;
  std::unique_ptr<int32_t[]> dim_list_host = std::make_unique<int32_t[]>(3 * num_gemms);

  cudaError_t copy_status = cudaMemcpyAsync(dim_list_host.get(), dim_list_ptr, dim_list_bytes,
                                            cudaMemcpyDeviceToHost, stream);
  NVTE_CHECK(copy_status == cudaSuccess,
             "Failed to copy GEMM dimension list to host: ", cudaGetErrorString(copy_status));
  // Note: This may break cudaGraph.
  // The host reads dim_list_host right below, so the copy must have completed.
  cudaError_t sync_status = cudaStreamSynchronize(stream);
  NVTE_CHECK(sync_status == cudaSuccess,
             "Failed to synchronize stream after copying GEMM dimension list: ",
             cudaGetErrorString(sync_status));

  // Notes on matrix layouts and transpose:
  // Jax uses row-major data_layout, on entering this function, each input matrix pair:
  //   A: row-major with size [m, k],
  //   B: row-major with size [n, k], needs transpose,
  // on exiting this function, JAX expect:
  //   C: row-major with size [m, n].
  // cuBLAS uses column-major data_layout, in this view, each input matrix pair:
  //   A: column-major with size [k, m], needs transpose,
  //   B: column-major with size [k, n].
  // If we call cuBLAS GEMM for A * B, the output will be:
  //   C: column-major with size [m, n] --> row-major with size [n, m].
  // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call.

  // NOTE(review): with the operands swapped in the call below, `trans_lhs` is
  // passed in the slot that follows `num_gemms` and so presumably applies to the
  // first operand (rhs_list) -- confirm against nvte_multi_stream_cublas_gemm.
  bool trans_lhs = true;
  bool trans_rhs = false;
  auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
  bool grad = false;
  bool accumulate = false;
  bool use_split_accumulator = false;

  // These lists are to keep the TensorWrapper objects alive
  std::vector<TensorWrapper> lhs_wrapper_list;
  std::vector<TensorWrapper> rhs_wrapper_list;
  std::vector<TensorWrapper> bias_wrapper_list;
  std::vector<TensorWrapper> pre_gelu_wrapper_list;
  std::vector<TensorWrapper> out_wrapper_list;
  std::vector<TensorWrapper> workspace_wrapper_list;

  // These lists are the actual NVTETensor (void *) lists for multi-stream GEMM
  std::vector<NVTETensor> lhs_list;
  std::vector<NVTETensor> rhs_list;
  std::vector<NVTETensor> bias_list;
  std::vector<NVTETensor> pre_gelu_list;
  std::vector<NVTETensor> out_list;
  std::vector<NVTETensor> workspace_list;

  // The number of per-GEMM entries is known up front; avoid repeated regrowth.
  lhs_wrapper_list.reserve(num_gemms);
  rhs_wrapper_list.reserve(num_gemms);
  bias_wrapper_list.reserve(num_gemms);
  pre_gelu_wrapper_list.reserve(num_gemms);
  out_wrapper_list.reserve(num_gemms);
  lhs_list.reserve(num_gemms);
  rhs_list.reserve(num_gemms);
  bias_list.reserve(num_gemms);
  pre_gelu_list.reserve(num_gemms);
  out_list.reserve(num_gemms);

  for (size_t i = 0; i < num_gemms; ++i) {
    size_t m = dim_list_host[i * 3];
    size_t n = dim_list_host[i * 3 + 1];
    size_t k = dim_list_host[i * 3 + 2];

    auto lhs_shape = std::vector<size_t>{m, k};
    auto rhs_shape = std::vector<size_t>{n, k};
    // Described as [n, m]; presumably the column-major view that results from
    // swapping the operands (see the layout notes above).
    auto out_shape = std::vector<size_t>{n, m};
    // Default scale-inverse extent: a single element per operand.
    auto lhs_sinv_shape = std::vector<size_t>{1, 1};
    auto rhs_sinv_shape = std::vector<size_t>{1, 1};

    auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
    auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
    lhs_i.set_rowwise_data(static_cast<void *>(lhs_ptr), lhs_dtype, lhs_shape);
    rhs_i.set_rowwise_data(static_cast<void *>(rhs_ptr), rhs_dtype, rhs_shape);

    if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
      lhs_i.set_rowwise_scale_inv(static_cast<void *>(lhs_sinv_ptr), DType::kFloat32,
                                  std::vector<size_t>{1});
      rhs_i.set_rowwise_scale_inv(static_cast<void *>(rhs_sinv_ptr), DType::kFloat32,
                                  std::vector<size_t>{1});
    } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
      // NVTE_CHECK concatenates its message arguments (same convention as the
      // NVTE_ERROR call below), so the values are passed as separate arguments.
      NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim must be divisible by ",
                 MXFP8_BLOCK_SIZE, " (got ", k, ")");
      size_t sinv_k = k / MXFP8_BLOCK_SIZE;
      lhs_sinv_shape[0] = m;
      lhs_sinv_shape[1] = sinv_k;
      rhs_sinv_shape[0] = n;
      rhs_sinv_shape[1] = sinv_k;

      // Note: the scale_inv array should have been swizzled in Python before lowering
      lhs_i.set_rowwise_scale_inv(static_cast<void *>(lhs_sinv_ptr), DType::kFloat8E8M0,
                                  lhs_sinv_shape);
      rhs_i.set_rowwise_scale_inv(static_cast<void *>(rhs_sinv_ptr), DType::kFloat8E8M0,
                                  rhs_sinv_shape);
    } else {
      NVTE_ERROR("Unsupported scaling mode: ", static_cast<int>(scaling_mode));
    }

    lhs_wrapper_list.push_back(std::move(lhs_i));
    rhs_wrapper_list.push_back(std::move(rhs_i));

    auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape, out_dtype);

    // Step every packed buffer past the slice consumed by this GEMM.
    lhs_ptr += m * k * lhs_dtype_bytes;
    rhs_ptr += n * k * rhs_dtype_bytes;
    out_ptr += m * n * out_dtype_bytes;
    lhs_sinv_ptr += lhs_sinv_shape[0] * lhs_sinv_shape[1] * lhs_sinv_dtype_bytes;
    rhs_sinv_ptr += rhs_sinv_shape[0] * rhs_sinv_shape[1] * rhs_sinv_dtype_bytes;

    // Bias is optional (shape {0} when absent); pre-GELU output is never produced
    // here, so it is always an empty tensor with a null pointer.
    void *pre_gelu_ptr = nullptr;
    auto bias_shape = std::vector<size_t>{0};
    auto pre_gelu_shape = std::vector<size_t>{0};
    if (bias_ptr != nullptr) bias_shape[0] = n;
    auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
    if (bias_ptr != nullptr) bias_ptr += n * bias_dtype_bytes;
    auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype);

    out_wrapper_list.push_back(std::move(out_i));
    bias_wrapper_list.push_back(std::move(bias_i));
    pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i));

    lhs_list.push_back(lhs_wrapper_list.back().data());
    rhs_list.push_back(rhs_wrapper_list.back().data());
    bias_list.push_back(bias_wrapper_list.back().data());
    pre_gelu_list.push_back(pre_gelu_wrapper_list.back().data());
    out_list.push_back(out_wrapper_list.back().data());
  }

  // Carve the workspace buffer into one `workspace_size`-byte slice per stream.
  auto workspace_shape = std::vector<size_t>{workspace_size};
  for (int i = 0; i < num_streams; i++) {
    auto workspace_i =
        TensorWrapper(static_cast<void *>(workspace_ptr), workspace_shape, DType::kByte);
    workspace_wrapper_list.push_back(std::move(workspace_i));
    workspace_list.push_back(workspace_wrapper_list.back().data());
    workspace_ptr += workspace_size;
  }

  // rhs is passed before lhs on purpose: see the layout notes above.
  nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(),
                                pre_gelu_list.data(), num_gemms, trans_lhs, trans_rhs, grad,
                                workspace_list.data(), accumulate, use_split_accumulator,
                                num_math_sm, stream);

  return ffi_with_cuda_error_check();
}

// XLA FFI entry point for the grouped GEMM.
//
// Extracts raw byte pointers and TE dtypes from the FFI buffers and forwards
// them to GroupedGemmImpl. The workspace result buffer is split evenly across
// `num_streams`: the per-stream size is its last dimension divided by
// `num_streams`.
Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_flatten,
                          Buffer_Type lhs_sinv_flatten, Buffer_Type rhs_flatten,
                          Buffer_Type rhs_sinv_flatten, Buffer_Type bias_flatten,
                          Buffer_Type dim_list, Result_Type out_flatten,
                          Result_Type workspace_flatten, int64_t num_gemms,
                          JAXX_Scaling_Mode scaling_mode) {
  // Inputs
  auto lhs_ptr = reinterpret_cast<uint8_t *>(lhs_flatten.untyped_data());
  auto rhs_ptr = reinterpret_cast<uint8_t *>(rhs_flatten.untyped_data());
  auto lhs_sinv_ptr = reinterpret_cast<uint8_t *>(lhs_sinv_flatten.untyped_data());
  auto rhs_sinv_ptr = reinterpret_cast<uint8_t *>(rhs_sinv_flatten.untyped_data());
  auto bias_ptr = reinterpret_cast<uint8_t *>(bias_flatten.untyped_data());
  auto dim_list_ptr = reinterpret_cast<int32_t *>(dim_list.untyped_data());
  auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_flatten.element_type());
  auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_flatten.element_type());
  auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv_flatten.element_type());
  auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv_flatten.element_type());
  auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias_flatten.element_type());

  // Outputs
  auto out_ptr = reinterpret_cast<uint8_t *>(out_flatten->untyped_data());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(out_flatten->element_type());
  auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace_flatten->untyped_data());
  // Per-stream workspace size; GroupedGemmImpl walks `num_streams` such slices.
  auto workspace_size = workspace_flatten->dimensions().back() / num_streams;

  return GroupedGemmImpl(lhs_ptr, lhs_dtype, lhs_sinv_ptr, lhs_sinv_dtype, rhs_ptr, rhs_dtype,
                         rhs_sinv_ptr, rhs_sinv_dtype, bias_ptr, bias_dtype, out_ptr, out_dtype,
                         workspace_ptr, workspace_size, num_gemms, dim_list_ptr, scaling_mode,
                         stream);
}

// Registers GroupedGemmFFI as the XLA FFI handler symbol `GroupedGemmHandler`.
// The binding order below must match GroupedGemmFFI's parameter order.
// NOTE(review): GroupedGemmImpl synchronizes the stream and carries a comment
// that this may break cudaGraph -- confirm FFI_CudaGraph_Traits is appropriate.
XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // lhs_flatten
                                  .Arg<Buffer_Type>()      // lhs_sinv_flatten
                                  .Arg<Buffer_Type>()      // rhs_flatten
                                  .Arg<Buffer_Type>()      // rhs_sinv_flatten
                                  .Arg<Buffer_Type>()      // bias_flatten
                                  .Arg<Buffer_Type>()      // dim_list
                                  .Ret<Buffer_Type>()      // out_flatten
                                  .Ret<Buffer_Type>()      // workspace_flatten
                                  .Attr<int64_t>("num_gemms")
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode"),
                              FFI_CudaGraph_Traits);

}  // namespace jax
}  // namespace transformer_engine