/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cudnn/Handle.h>
#include <ATen/native/DispatchStub.h>
#include <c10/macros/Macros.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <cuda_runtime.h>
#ifndef USE_ROCM
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_bf16.h>
#include <cudnn.h>
#else
#include <hip/hip_bf16.h>
#endif
#include <torch/extension.h>
#include <torch/torch.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/cast_transpose_noop.h>
#include <transformer_engine/comm_gemm_overlap.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/fused_router.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/hadamard_transform.h>
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/padding.h>
#include <transformer_engine/permutation.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/softmax.h>
#include <transformer_engine/swizzle.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>

#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

#include "c10/util/ArrayRef.h"
#include "common/util/logging.h"
namespace transformer_engine::pytorch {

// in python we have: dist_group_type = torch.distributed.ProcessGroup
using dist_group_type = c10d::ProcessGroup;
/*! @brief Block length used for blockwise FP8 quantization.
 *
 * Reads the NVTE_BLOCKWISE_FP8_BLOCK_LEN environment variable; returns
 * 128 when the variable is unset or empty. Aborts (via NVTE_CHECK) when
 * the variable is set but does not parse to a positive integer.
 */
inline int blockwise_fp8_block_len() {
  const char *env = std::getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN");
  if (env == nullptr || env[0] == '\0') {
    return 128;
  }
  int value;
  std::istringstream iss(env);
  iss >> value;
  // Name the variable and echo the bad value so the failure is actionable.
  NVTE_CHECK(iss, "Invalid NVTE_BLOCKWISE_FP8_BLOCK_LEN value \"", env, "\"");
  NVTE_CHECK(value > 0, "NVTE_BLOCKWISE_FP8_BLOCK_LEN must be positive, got ", value);
  return value;
}

Przemek Tredak's avatar
Przemek Tredak committed
80
81
82
83
// Each tensor here is shape (N, ) holding all scaling
// data for a single FP8 block, e.g. LayerNormLinear
class FP8TensorMeta {
 public:
84
85
86
  at::Tensor scale;
  at::Tensor scale_inv;
  at::Tensor amax_history;
Przemek Tredak's avatar
Przemek Tredak committed
87
88
89
90
91
};

// Used as named indices on the `scale`, `scale_inv`,
// and `amax` tensors in the `FP8TensorMeta` class.
// Forward pass: three GEMMs, each with an input/weight/output slot.
enum FP8FwdTensors {
  GEMM1_INPUT = 0,
  GEMM1_WEIGHT = 1,
  GEMM1_OUTPUT = 2,
  GEMM2_INPUT = 3,
  GEMM2_WEIGHT = 4,
  GEMM2_OUTPUT = 5,
  GEMM3_INPUT = 6,
  GEMM3_WEIGHT = 7,
  GEMM3_OUTPUT = 8
};

// Used as named indices on the `scale`, `scale_inv`,
// and `amax` tensors in the `FP8TensorMeta` class.
// Backward pass: gradient output/input pairs for the three GEMMs.
enum FP8BwdTensors {
  GRAD_OUTPUT1 = 0,
  GRAD_INPUT1 = 1,
  GRAD_OUTPUT2 = 2,
  GRAD_INPUT2 = 3,
  GRAD_OUTPUT3 = 4,
  GRAD_INPUT3 = 5
};

114
115
116
117
118
119
class Quantizer {
 public:
  virtual NVTEScalingMode get_scaling_mode() const = 0;

  virtual void set_quantization_params(TensorWrapper* tensor) const = 0;

120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
  /*! @brief Construct a tensor with uninitialized data */
  virtual std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                             DType dtype) const = 0;

  /*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor
   *
   * The PyTorch tensor's attributes are modified to match the
   * quantizer's configuration.
   */
  virtual std::pair<TensorWrapper, py::object> convert_and_update_tensor(
      py::object tensor) const = 0;

  /*! @brief Convert to a quantized data format */
  virtual void quantize(const TensorWrapper& input, TensorWrapper& out,
                        const std::optional<TensorWrapper>& noop_flag = std::nullopt) = 0;
135
136
137
138
139
140

  virtual ~Quantizer() = default;

  bool rowwise_usage = true;
  bool columnwise_usage = true;
  bool internal = false;
141
  bool optimize_for_gemm = false;
142
143
144
145
146
147
148
149
150
151
152
153
154
155
  py::handle quantizer;

 protected:
  explicit Quantizer(const py::handle& quantizer);
};

class NoneQuantizer : public Quantizer {
 public:
  explicit NoneQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {}

  NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; }

  void set_quantization_params(TensorWrapper* tensor) const override {}

156
157
158
159
160
161
162
163
164
165
166
  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                     DType dtype) const override;

  /*! @brief Construct a tensor with pre-initialized data */
  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
                                                     at::Tensor data) const;

  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object tensor) const override;

  void quantize(const TensorWrapper& input, TensorWrapper& out,
                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
};

class Float8Quantizer : public Quantizer {
 public:
  at::Tensor scale;
  at::Tensor scale_inv;
  at::Tensor amax;
  DType dtype;

  explicit Float8Quantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; }

  void set_quantization_params(TensorWrapper* tensor) const override;

182
183
184
185
186
187
188
189
190
191
192
193
194
  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                     DType dtype) const override;

  /*! @brief Construct a tensor with pre-initialized data */
  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
                                                     std::optional<at::Tensor> data,
                                                     std::optional<at::Tensor> transpose,
                                                     std::optional<at::Tensor> scale_inv) const;

  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;

  void quantize(const TensorWrapper& input, TensorWrapper& out,
                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
};

class Float8CurrentScalingQuantizer : public Quantizer {
 public:
  at::Tensor scale;
  at::Tensor scale_inv;
  at::Tensor amax;
  DType dtype;
  bool with_amax_reduction;
  c10::intrusive_ptr<dist_group_type> amax_reduction_group;
  bool force_pow_2_scales = false;
  float amax_epsilon = 0.0;

  explicit Float8CurrentScalingQuantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; }

  void set_quantization_params(TensorWrapper* tensor) const override;

214
215
216
  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                     DType dtype) const override;

217
218
219
220
  /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer.
   *
   * The amax is zeroed out. Most TE kernels that output amax expect
   * amax to be initialized to zero.
221
  */
222
  std::pair<TensorWrapper, py::object> create_unquantized_tensor_with_amax(
223
      const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> data = std::nullopt);
224

225
226
227
228
  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;

  void quantize(const TensorWrapper& input, TensorWrapper& out,
                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
229

230
231
232
233
234
235
  /*! @brief Quantize to FP8, skipping local amax computation
   *
   * The quantizer's amax pointer is assumed to already hold the local
   * amax. The amax may still be reduced across the amax reduction
   * group.
   */
236
237
238
239
240
241
  void quantize_with_amax(TensorWrapper& input, TensorWrapper& out,
                          const std::optional<TensorWrapper>& noop_flag = std::nullopt);

 private:
  void quantize_impl(const TensorWrapper& input, TensorWrapper& out,
                     const std::optional<TensorWrapper>& noop_flag, bool compute_amax);
242
243
244
245
246
247
248
249
250
251
252
};

class Float8BlockQuantizer : public Quantizer {
 public:
  // Which float8 type is used for q data.
  DType dtype;
  // Options about how to quantize the tensor
  // Quantization scales are rounded down to powers of 2.
  bool force_pow_2_scales = false;
  // Amax within quantization tile has a floor of epsilon.
  float amax_epsilon = 0.0;
253
254

 private:
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
  int block_scaling_dim = 2;

 public:
  // Initializes from a python handle to a Float8BlockQuantizer
  explicit Float8BlockQuantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override {
    return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D;
  }

  // Gets rowwise and columnwise_data from tensor and sets them on wrapper
  void set_quantization_params(TensorWrapper* tensor) const override;

  // Create a python Float8BlockQuantized tensor and C++ wrapper
  // for the tensor. Should set quantized data, scales for rowwise
  // and optionally columnwise usage.
271
272
273
274
275
276
277
  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                     DType dtype) const override;

  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;

  void quantize(const TensorWrapper& input, TensorWrapper& out,
                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
278
279

  std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
280
281
282
283
284
285
286
287
288
289
290
291
};

class MXFP8Quantizer : public Quantizer {
 public:
  DType dtype;

  explicit MXFP8Quantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override { return NVTE_MXFP8_1D_SCALING; }

  void set_quantization_params(TensorWrapper* tensor) const override;

292
293
294
295
296
297
298
  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                     DType dtype) const override;

  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;

  void quantize(const TensorWrapper& input, TensorWrapper& out,
                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
299
300

  std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
301
302
};

303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
class NVFP4Quantizer : public Quantizer {
 public:
  // fp4 dtype
  DType dtype;
  // amax reduction for low precision FP4 AG
  bool with_amax_reduction;
  c10::intrusive_ptr<dist_group_type> amax_reduction_group;
  // random hadamard transform
  bool with_rht;
  bool with_post_rht_amax;
  // 2D block scaling
  bool with_2d_quantization;
  bool stochastic_rounding;

  int rht_matrix_random_sign_mask_t;
  at::Tensor rht_matrix;

  explicit NVFP4Quantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override { return NVTE_NVFP4_1D_SCALING; }

  void set_quantization_params(TensorWrapper* tensor) const override;

  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                     DType dtype) const override;

  /*! @brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer
   *
   * The amax is zeroed out. Most TE kernels that output amax expect
   * amax to be initialized to zero.
   */
  std::pair<TensorWrapper, py::object> create_unquantized_tensor_with_amax(
      TensorWrapper& quantized_tensor, DType dtype);

  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;

  void quantize(const TensorWrapper& input, TensorWrapper& out,
                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;

  /*! @brief Quantize to NVFP4, skipping local amax computation
   *
   * The input tensor's amax pointer is assumed to already hold the
   * local amax. The amax may still be reduced across the amax
   * reduction group.
   */
  void quantize_with_amax(TensorWrapper& input, TensorWrapper& out);

  std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;

 private:
  void quantize_impl(const TensorWrapper& input, TensorWrapper& out,
                     const std::optional<TensorWrapper>& noop_flag, bool compute_amax);
};

357
358
std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);

359
std::vector<size_t> getTensorShape(const at::Tensor& t);
Przemek Tredak's avatar
Przemek Tredak committed
360
361

transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
362
                                                      const std::string& fp8_recipe);
Przemek Tredak's avatar
Przemek Tredak committed
363

364
/*! @brief Number of bits per element for a TE dtype.
 *
 * Aborts (via NVTE_ERROR) on dtypes without a defined width.
 */
inline size_t typeToNumBits(transformer_engine::DType t) {
  switch (t) {
    case transformer_engine::DType::kInt64:
      return 64;
    case transformer_engine::DType::kInt32:
    case transformer_engine::DType::kFloat32:
      return 32;
    case transformer_engine::DType::kInt16:
    case transformer_engine::DType::kFloat16:
    case transformer_engine::DType::kBFloat16:
      return 16;
    case transformer_engine::DType::kByte:
    case transformer_engine::DType::kFloat8E4M3:
    case transformer_engine::DType::kFloat8E5M2:
    case transformer_engine::DType::kFloat8E8M0:
    case transformer_engine::DType::kInt8:
      return 8;
#if FP4_TYPE_SUPPORTED
    case transformer_engine::DType::kFloat4E2M1:
      return 4;
#endif
    default:
      NVTE_ERROR("Invalid type (", static_cast<int>(t), ").");
  }
}

Przemek Tredak's avatar
Przemek Tredak committed
390
inline at::ScalarType GetATenDType(transformer_engine::DType t) {
391
  switch (t) {
392
393
    case transformer_engine::DType::kInt8:
      return torch::kInt8;
394
395
    case transformer_engine::DType::kInt16:
      return torch::kInt16;
396
397
398
399
400
401
402
403
404
405
406
    case transformer_engine::DType::kInt32:
      return torch::kInt32;
    case transformer_engine::DType::kInt64:
      return torch::kInt64;
    case transformer_engine::DType::kFloat32:
      return at::kFloat;
    case transformer_engine::DType::kFloat16:
      return at::kHalf;
    case transformer_engine::DType::kBFloat16:
      return at::kBFloat16;
    case transformer_engine::DType::kByte:
407
      return at::kByte;
408
    case transformer_engine::DType::kFloat8E4M3:
409
      return at::kFloat8_e4m3fn;
410
    case transformer_engine::DType::kFloat8E5M2:
411
      return at::kFloat8_e5m2;
412
413
    case transformer_engine::DType::kFloat8E8M0:
      return at::kByte;  // e8m0 dtype requires PyTorch 2.7.0+
414
    default:
415
      NVTE_ERROR("Invalid type (", static_cast<int>(t), ").");
416
  }
Przemek Tredak's avatar
Przemek Tredak committed
417
418
419
}

/*! @brief Map an ATen scalar type to the corresponding TE dtype.
 *
 * Note kBool and kByte both map to kByte. Aborts (via NVTE_ERROR) on
 * unsupported scalar types.
 */
inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
  switch (t) {
    case at::kFloat8_e4m3fn:
      return transformer_engine::DType::kFloat8E4M3;
    case at::kFloat8_e5m2:
      return transformer_engine::DType::kFloat8E5M2;
    case at::kHalf:
      return transformer_engine::DType::kFloat16;
    case at::kFloat:
      return transformer_engine::DType::kFloat32;
    case at::kBFloat16:
      return transformer_engine::DType::kBFloat16;
    case at::kBool:
      return transformer_engine::DType::kByte;
    case torch::kByte:
      return transformer_engine::DType::kByte;
    case torch::kInt16:
      return transformer_engine::DType::kInt16;
    case torch::kInt32:
      return transformer_engine::DType::kInt32;
    case torch::kInt64:
      return transformer_engine::DType::kInt64;
    case torch::kInt8:
      return transformer_engine::DType::kInt8;
    default:
      NVTE_ERROR("Invalid type (", static_cast<int>(t), ").");
  }
}

/*! @brief Reinterpret an integer as a TE dtype (e.g. a value passed from Python).
 *
 * NOTE(review): the cast is unchecked — an out-of-range value produces an
 * invalid enum; callers are expected to pass valid DType values.
 */
inline transformer_engine::DType GetTransformerEngineDType(int DType_value) {
  return static_cast<transformer_engine::DType>(DType_value);
}

transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
                                                              const std::vector<size_t>& shape,
454
                                                              const transformer_engine::DType type);
Przemek Tredak's avatar
Przemek Tredak committed
455

456
457
458
459
460
461
462
463
464
465
466
467
transformer_engine::TensorWrapper makeTransformerEngineTensor(
    void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type,
    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector<size_t> scale_inv_shape = {1},
    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

transformer_engine::TensorWrapper makeTransformerEngineTensor(
    void* data_ptr, void* columnwise_data_ptr, const std::vector<size_t>& shape,
    const std::vector<size_t>& columnwise_shape, const transformer_engine::DType type,
    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
    const std::vector<size_t>& scale_inv_shape = {1},
    const std::vector<size_t>& columnwise_scale_inv_shape = {1},
    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);
Przemek Tredak's avatar
Przemek Tredak committed
468
469
470

transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
                                                              const NVTEShape& shape,
471
                                                              const transformer_engine::DType type);
Przemek Tredak's avatar
Przemek Tredak committed
472
473
474

transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor);

475
476
477
478
std::tuple<std::vector<transformer_engine::TensorWrapper>, std::vector<std::vector<NVTETensor>>,
           std::vector<NVTETensor*>, size_t, size_t>
makeTransformerEngineTensorList(std::vector<std::vector<at::Tensor>> at_tensor_lists);

479
480
481
482
483
484
485
486
TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer);

transformer_engine::TensorWrapper makeTransformerEngineTensor(
    at::Tensor tensor, at::Tensor amax, const at::Tensor scale, at::Tensor scale_inv,
    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

template <typename T>
T product(const std::vector<T>& shape);
487

488
489
490
size_t product(const NVTEShape& shape, size_t begin, size_t end);

std::vector<size_t> nvte_shape_to_vector(const NVTEShape& nvte_shape);
Przemek Tredak's avatar
Przemek Tredak committed
491

492
at::Tensor allocateSpace(const std::vector<size_t>& shape, const transformer_engine::DType type,
cyanguwa's avatar
cyanguwa committed
493
                         bool init_to_zeros);
Przemek Tredak's avatar
Przemek Tredak committed
494

495
at::Tensor allocateSpace(const NVTEShape& shape, const transformer_engine::DType type,
Przemek Tredak's avatar
Przemek Tredak committed
496
497
                         bool init_to_zeros = false);

498
at::Tensor allocateTorchTensor(int M, int N, transformer_engine::DType dtype);
Przemek Tredak's avatar
Przemek Tredak committed
499

500
at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype);
Przemek Tredak's avatar
Przemek Tredak committed
501

502
void* getDataPtr(at::Tensor tensor, int offset = 0);
503

504
505
std::vector<size_t> convertShape(const NVTEShape& shape);

506
507
508
size_t roundup(size_t value, size_t multiple);

size_t ceildiv(size_t numer, size_t denom);
509

510
NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
511
512
513
514
515
516
517
518
519

std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape, bool transpose);

// unpack the PhiloxCudaState into CUDA tensor
void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr);

// extract PhiloxCudaState from CUDA random number generator
at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread);

520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
}  // namespace transformer_engine::pytorch

namespace std {
// NOTE(review): adding overloads to namespace std is formally undefined
// behavior per the C++ standard. These helpers work in practice for
// debug/log formatting, but consider moving them into
// transformer_engine::pytorch — confirm no callers rely on ADL via std.

// Format a vector as "[a,b,c]" (empty vector -> "[]").
template <typename T>
string to_string(const vector<T>& vec) {
  string ret = "[";
  for (const auto& val : vec) {
    ret += to_string(val) + ",";
  }
  // Non-empty: overwrite the trailing comma with the closing bracket.
  if (ret.size() > 1) {
    ret[ret.size() - 1] = ']';
  } else {
    ret += "]";
  }
  return ret;
}

// Torch shape -> string
template <typename T>
string to_string(const c10::ArrayRef<T>& vec) {
  string ret = "[";
  for (const auto& val : vec) {
    ret += to_string(val) + ",";
  }
  // Non-empty: overwrite the trailing comma with the closing bracket.
  if (ret.size() > 1) {
    ret[ret.size() - 1] = ']';
  } else {
    ret += "]";
  }
  return ret;
}

// Format an NVTEShape as "[d0,d1,...]" (zero dims -> "[]").
inline string to_string(const NVTEShape& s) {
  string ret = "[";
  for (size_t i = 0; i < s.ndim; ++i) {
    ret += to_string(s.data[i]) + ",";
  }
  // Non-empty: overwrite the trailing comma with the closing bracket.
  if (ret.size() > 1) {
    ret[ret.size() - 1] = ']';
  } else {
    ret += "]";
  }
  return ret;
}
}  // namespace std

#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_