// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include "ds_kernel_utils.h"

#include <cuda_fp16.h>
#include <stdint.h>

#ifdef BF16_AVAILABLE
#include <cuda_bf16.h>
#endif

namespace conversion {

// Basic primitive for constructing conversions
// Primary template of conversion::to<TO>(FROM).
//
// Supported (TO, FROM) pairs are supplied as explicit specializations of this
// template; the primary template is only reached when no specialization
// matches. The static_assert condition depends on TO, so it is evaluated only
// if the primary template is actually instantiated, turning an unsupported
// conversion into a readable compile-time error instead of a failed
// self-call whose TO cannot be deduced.
template <typename TO, typename FROM>
DS_D_INLINE TO to(FROM val)
{
    static_assert(sizeof(TO) == 0,
                  "conversion::to: no specialization exists for this (TO, FROM) pair");
    (void)val;  // silence unused-parameter diagnostics; this body never runs
    return TO();
}

// Specializations

/********************* Identity Conversions *********************/
/*
Identity conversions are useful in templated functions where we might have
a fixed destination type. For example, I might have a kernel that accepts
__half, __nv_bfloat16, and float but always want to do the core computation
at floating point:

T mem_value = input[idx];
float compute_value = conversion::to<float, T>(mem_value);

In practice, we should be able to elide the second template parameter:
float compute_val = conversion::to<float>(mem_value);

In this case, we need an implementation to handle the T = float case

NOTE: The type inferencing system appears to be unable to handle inferring the first
template parameter, even in the trivial case.
*/

// Floating point types
// Identity: a double is handed back unchanged.
template <>
DS_D_INLINE double to<double, double>(double val)
{
    return val;
}
// Identity: a float is handed back unchanged.
template <>
DS_D_INLINE float to<float, float>(float val)
{
    return val;
}
// Identity: a __half is handed back unchanged.
template <>
DS_D_INLINE __half to<__half, __half>(__half val)
{
    return val;
}
#ifdef BF16_AVAILABLE
// Identity: a __nv_bfloat16 is handed back unchanged.
template <>
DS_D_INLINE __nv_bfloat16 to<__nv_bfloat16, __nv_bfloat16>(__nv_bfloat16 val)
{
    return val;
}
#endif

// Integer types
// Identity: an int8_t is handed back unchanged.
template <>
DS_D_INLINE int8_t to<int8_t, int8_t>(int8_t val)
{
    return val;
}
// Identity: a uint8_t is handed back unchanged.
template <>
DS_D_INLINE uint8_t to<uint8_t, uint8_t>(uint8_t val)
{
    return val;
}
// Identity: an int16_t is handed back unchanged.
template <>
DS_D_INLINE int16_t to<int16_t, int16_t>(int16_t val)
{
    return val;
}
// Identity: a uint16_t is handed back unchanged.
template <>
DS_D_INLINE uint16_t to<uint16_t, uint16_t>(uint16_t val)
{
    return val;
}
// Identity: an int32_t is handed back unchanged.
template <>
DS_D_INLINE int32_t to<int32_t, int32_t>(int32_t val)
{
    return val;
}
// Identity: a uint32_t is handed back unchanged.
template <>
DS_D_INLINE uint32_t to<uint32_t, uint32_t>(uint32_t val)
{
    return val;
}
// Identity: an int64_t is handed back unchanged.
template <>
DS_D_INLINE int64_t to<int64_t, int64_t>(int64_t val)
{
    return val;
}
// Identity: a uint64_t is handed back unchanged.
template <>
DS_D_INLINE uint64_t to<uint64_t, uint64_t>(uint64_t val)
{
    return val;
}

// TODO: evaluate if we want bools

/*********************  To Double Conversions *********************/

// * to double variants

// Would normally like to not use C cast, but this is an important enough conversion
// to keep
// float -> double.
// When PTX_AVAILABLE is defined the conversion is issued as an explicit PTX
// `cvt`; otherwise it falls back to a plain cast.
template <>
DS_D_INLINE double to(float val)
{
#ifdef PTX_AVAILABLE
    double ret_val;
    // The PTX mnemonic is `cvt` (the previous `ctv` is not an instruction).
    // Widening f32 -> f64 loses no precision, and the PTX ISA does not permit
    // a floating-point rounding modifier on such a conversion, so none is
    // given here.
    asm("cvt.f64.f32 %0, %1;\n" : "=d"(ret_val) : "f"(val));
    return ret_val;
#else
    return double(val);
#endif
}
// Note: there is a CVT instruction for __half -> double, but there's no inline interface
// for passing a single half value
// __half -> double, staged through float: __half2float first, then the
// float -> double specialization of to().
template <>
DS_D_INLINE double to<double, __half>(__half val)
{
    const float as_float = __half2float(val);
    return to<double>(as_float);
}
// int64_t -> double via the __ll2double_rn intrinsic.
template <>
DS_D_INLINE double to<double, int64_t>(int64_t val)
{
    return __ll2double_rn(static_cast<long long>(val));
}
// int32_t -> double via the __int2double_rn intrinsic.
template <>
DS_D_INLINE double to<double, int32_t>(int32_t val)
{
    return __int2double_rn(val);
}
// int16_t -> double: the argument is widened to int, then converted with
// __int2double_rn.
template <>
DS_D_INLINE double to<double, int16_t>(int16_t val)
{
    return __int2double_rn(static_cast<int>(val));
}
// int8_t -> double: the argument is widened to int, then converted with
// __int2double_rn.
template <>
DS_D_INLINE double to<double, int8_t>(int8_t val)
{
    return __int2double_rn(static_cast<int>(val));
}
// uint64_t -> double via the __ull2double_rn intrinsic.
template <>
DS_D_INLINE double to<double, uint64_t>(uint64_t val)
{
    return __ull2double_rn(static_cast<unsigned long long>(val));
}
// uint32_t -> double via the __uint2double_rn intrinsic.
template <>
DS_D_INLINE double to<double, uint32_t>(uint32_t val)
{
    return __uint2double_rn(val);
}
// uint16_t -> double: the argument is widened to unsigned int, then converted
// with __uint2double_rn.
template <>
DS_D_INLINE double to<double, uint16_t>(uint16_t val)
{
    return __uint2double_rn(static_cast<unsigned int>(val));
}
// uint8_t -> double: the argument is widened to unsigned int, then converted
// with __uint2double_rn.
template <>
DS_D_INLINE double to<double, uint8_t>(uint8_t val)
{
    return __uint2double_rn(static_cast<unsigned int>(val));
}

// Same applies here
#ifdef BF16_AVAILABLE
// __nv_bfloat16 -> double, staged through float: __bfloat162float first, then
// the float -> double specialization of to().
template <>
DS_D_INLINE double to<double, __nv_bfloat16>(__nv_bfloat16 val)
{
    const float as_float = __bfloat162float(val);
    return to<double>(as_float);
}
#endif

/*********************  To Float Conversions *********************/

// double -> float via the __double2float_rn intrinsic.
template <>
DS_D_INLINE float to<float, double>(double val)
{
    return __double2float_rn(val);
}
// __half -> float via the __half2float intrinsic.
template <>
DS_D_INLINE float to<float, __half>(__half val)
{
    return __half2float(val);
}
// int64_t -> float via the __ll2float_rn intrinsic.
template <>
DS_D_INLINE float to<float, int64_t>(int64_t val)
{
    return __ll2float_rn(static_cast<long long>(val));
}
// int32_t -> float via the __int2float_rn intrinsic.
template <>
DS_D_INLINE float to<float, int32_t>(int32_t val)
{
    return __int2float_rn(val);
}
// int16_t -> float: the argument is widened to int, then converted with
// __int2float_rn.
template <>
DS_D_INLINE float to<float, int16_t>(int16_t val)
{
    return __int2float_rn(static_cast<int>(val));
}
// int8_t -> float: the argument is widened to int, then converted with
// __int2float_rn.
template <>
DS_D_INLINE float to<float, int8_t>(int8_t val)
{
    return __int2float_rn(static_cast<int>(val));
}
// uint64_t -> float via the __ull2float_rn intrinsic.
template <>
DS_D_INLINE float to<float, uint64_t>(uint64_t val)
{
    return __ull2float_rn(static_cast<unsigned long long>(val));
}
// uint32_t -> float via the __uint2float_rn intrinsic.
template <>
DS_D_INLINE float to<float, uint32_t>(uint32_t val)
{
    return __uint2float_rn(val);
}
// uint16_t -> float: the argument is widened to unsigned int, then converted
// with __uint2float_rn.
template <>
DS_D_INLINE float to<float, uint16_t>(uint16_t val)
{
    return __uint2float_rn(static_cast<unsigned int>(val));
}
// uint8_t -> float: the argument is widened to unsigned int, then converted
// with __uint2float_rn.
template <>
DS_D_INLINE float to<float, uint8_t>(uint8_t val)
{
    return __uint2float_rn(static_cast<unsigned int>(val));
}

#ifdef BF16_AVAILABLE
// __nv_bfloat16 -> float via the __bfloat162float intrinsic.
template <>
DS_D_INLINE float to<float, __nv_bfloat16>(__nv_bfloat16 val)
{
    return __bfloat162float(val);
}
#endif

/*********************  To Float2 Conversions *********************/
// __half2 -> float2 via the __half22float2 intrinsic.
template <>
DS_D_INLINE float2 to<float2, __half2>(__half2 val)
{
    return __half22float2(val);
}

#ifdef BF16_AVAILABLE
// __nv_bfloat162 -> float2 via the __bfloat1622float2 intrinsic.
template <>
DS_D_INLINE float2 to<float2, __nv_bfloat162>(__nv_bfloat162 val)
{
    return __bfloat1622float2(val);
}
#endif

/*********************  To Half Conversions *********************/
// double -> __half.
// When __HIP_PLATFORM_HCC__ is defined the value is first narrowed to float
// with __double2float_rn and then converted with __float2half; otherwise
// __double2half is called directly.
template <>
DS_D_INLINE __half to<__half, double>(double val)
{
#ifdef __HIP_PLATFORM_HCC__
    float val_f = __double2float_rn(val);
    return __float2half(val_f);
#else
    return __double2half(val);
#endif
}
// float -> __half via the __float2half intrinsic.
template <>
DS_D_INLINE __half to<__half, float>(float val)
{
    return __float2half(val);
}
// int64_t -> __half via the __ll2half_rn intrinsic.
template <>
DS_D_INLINE __half to<__half, int64_t>(int64_t val)
{
    return __ll2half_rn(static_cast<long long>(val));
}
// int32_t -> __half via the __int2half_rn intrinsic.
template <>
DS_D_INLINE __half to<__half, int32_t>(int32_t val)
{
    return __int2half_rn(val);
}
// int16_t -> __half via the __short2half_rn intrinsic.
template <>
DS_D_INLINE __half to<__half, int16_t>(int16_t val)
{
    return __short2half_rn(val);
}
// int8_t -> __half: the argument is widened to int, then converted with
// __int2half_rn.
template <>
DS_D_INLINE __half to<__half, int8_t>(int8_t val)
{
    return __int2half_rn(static_cast<int>(val));
}
// uint64_t -> __half via the __ull2half_rn intrinsic.
template <>
DS_D_INLINE __half to<__half, uint64_t>(uint64_t val)
{
    return __ull2half_rn(static_cast<unsigned long long>(val));
}
// uint32_t -> __half via the __uint2half_rn intrinsic.
template <>
DS_D_INLINE __half to<__half, uint32_t>(uint32_t val)
{
    return __uint2half_rn(val);
}
// uint16_t -> __half via the __ushort2half_rn intrinsic.
template <>
DS_D_INLINE __half to<__half, uint16_t>(uint16_t val)
{
    return __ushort2half_rn(val);
}
// uint8_t -> __half: the argument is widened to unsigned int, then converted
// with __uint2half_rn.
template <>
DS_D_INLINE __half to<__half, uint8_t>(uint8_t val)
{
    return __uint2half_rn(static_cast<unsigned int>(val));
}

#ifdef BF16_AVAILABLE
// No direct conversion
// __nv_bfloat16 -> __half, staged through float using the other to()
// specializations (bf16 -> float, then float -> half).
template <>
DS_D_INLINE __half to<__half, __nv_bfloat16>(__nv_bfloat16 val)
{
    const float as_float = to<float>(val);
    return to<__half>(as_float);
}
#endif

/*********************  To Half2 Conversions *********************/
// float2 -> __half2 via the __float22half2_rn intrinsic.
template <>
DS_D_INLINE __half2 to<__half2, float2>(float2 val)
{
    return __float22half2_rn(val);
}
// float -> __half2 via the __float2half2_rn intrinsic (a single scalar in, a
// packed pair out; per the CUDA math API the converted value fills both lanes).
template <>
DS_D_INLINE __half2 to<__half2, float>(float val)
{
    return __float2half2_rn(val);
}

#ifdef BF16_AVAILABLE
// No direct conversion
// __nv_bfloat162 -> __half2, staged through float2 using the other to()
// specializations (bf162 -> float2, then float2 -> half2).
template <>
DS_D_INLINE __half2 to<__half2, __nv_bfloat162>(__nv_bfloat162 val)
{
    const float2 as_float2 = to<float2>(val);
    return to<__half2>(as_float2);
}
#endif

/*********************  To BF16 Conversions *********************/
#ifdef BF16_AVAILABLE
// double -> __nv_bfloat16 via the __double2bfloat16 intrinsic.
template <>
DS_D_INLINE __nv_bfloat16 to<__nv_bfloat16, double>(double val)
{
    return __double2bfloat16(val);
}
// float -> __nv_bfloat16 via the __float2bfloat16 intrinsic.
template <>
DS_D_INLINE __nv_bfloat16 to<__nv_bfloat16, float>(float val)
{
    return __float2bfloat16(val);
}
// int64_t -> __nv_bfloat16 via the __ll2bfloat16_rn intrinsic.
template <>
DS_D_INLINE __nv_bfloat16 to<__nv_bfloat16, int64_t>(int64_t val)
{
    return __ll2bfloat16_rn(static_cast<long long>(val));
}
// int32_t -> __nv_bfloat16 via the __int2bfloat16_rn intrinsic.
template <>
DS_D_INLINE __nv_bfloat16 to<__nv_bfloat16, int32_t>(int32_t val)
{
    return __int2bfloat16_rn(val);
}
// int16_t -> __nv_bfloat16 via the __short2bfloat16_rn intrinsic.
template <>
DS_D_INLINE __nv_bfloat16 to<__nv_bfloat16, int16_t>(int16_t val)
{
    return __short2bfloat16_rn(val);
}
// int8_t -> __nv_bfloat16: the argument is widened to int, then converted
// with __int2bfloat16_rn.
template <>
DS_D_INLINE __nv_bfloat16 to<__nv_bfloat16, int8_t>(int8_t val)
{
    return __int2bfloat16_rn(static_cast<int>(val));
}
// uint64_t -> __nv_bfloat16 via the __ull2bfloat16_rn intrinsic.
template <>
DS_D_INLINE __nv_bfloat16 to<__nv_bfloat16, uint64_t>(uint64_t val)
{
    return __ull2bfloat16_rn(static_cast<unsigned long long>(val));
}
// uint32_t -> __nv_bfloat16 via the __uint2bfloat16_rn intrinsic.
template <>
DS_D_INLINE __nv_bfloat16 to<__nv_bfloat16, uint32_t>(uint32_t val)
{
    return __uint2bfloat16_rn(val);
}
// uint16_t -> __nv_bfloat16 via the __ushort2bfloat16_rn intrinsic.
template <>
DS_D_INLINE __nv_bfloat16 to<__nv_bfloat16, uint16_t>(uint16_t val)
{
    return __ushort2bfloat16_rn(val);
}
// uint8_t -> __nv_bfloat16: the argument is widened to unsigned int, then
// converted with __uint2bfloat16_rn.
template <>
DS_D_INLINE __nv_bfloat16 to<__nv_bfloat16, uint8_t>(uint8_t val)
{
    return __uint2bfloat16_rn(static_cast<unsigned int>(val));
}
#endif

/*********************  To BF162 Conversions *********************/
#ifdef BF16_AVAILABLE
// float2 -> __nv_bfloat162 via the __float22bfloat162_rn intrinsic.
template <>
DS_D_INLINE __nv_bfloat162 to<__nv_bfloat162, float2>(float2 val)
{
    return __float22bfloat162_rn(val);
}
// float -> __nv_bfloat162 via the __float2bfloat162_rn intrinsic (a single
// scalar in, a packed pair out).
template <>
DS_D_INLINE __nv_bfloat162 to<__nv_bfloat162, float>(float val)
{
    return __float2bfloat162_rn(val);
}
// __half2 -> __nv_bfloat162, staged through float2 using the other to()
// specializations (half2 -> float2, then float2 -> bf162).
template <>
DS_D_INLINE __nv_bfloat162 to<__nv_bfloat162, __half2>(__half2 val)
{
    return to<__nv_bfloat162>(to<float2>(val));
}
#endif

/*********************  To INT64_T Conversions *********************/
// double -> int64_t via the __double2ll_rn intrinsic.
template <>
DS_D_INLINE int64_t to<int64_t, double>(double val)
{
    return static_cast<int64_t>(__double2ll_rn(val));
}
// float -> int64_t via the __float2ll_rn intrinsic.
template <>
DS_D_INLINE int64_t to<int64_t, float>(float val)
{
    return static_cast<int64_t>(__float2ll_rn(val));
}
// __half -> int64_t via the __half2ll_rn intrinsic.
template <>
DS_D_INLINE int64_t to<int64_t, __half>(__half val)
{
    return static_cast<int64_t>(__half2ll_rn(val));
}
// No direct support for integer casts at the C++ level and I don't feel they're so important
// to demand an PTX at this time

#ifdef BF16_AVAILABLE
// __nv_bfloat16 -> int64_t via the __bfloat162ll_rn intrinsic.
template <>
DS_D_INLINE int64_t to<int64_t, __nv_bfloat16>(__nv_bfloat16 val)
{
    return static_cast<int64_t>(__bfloat162ll_rn(val));
}
#endif

/*********************  To INT32_T Conversions *********************/
// double -> int32_t via the __double2int_rn intrinsic.
template <>
DS_D_INLINE int32_t to<int32_t, double>(double val)
{
    return __double2int_rn(val);
}
// float -> int32_t via the __float2int_rn intrinsic.
template <>
DS_D_INLINE int32_t to<int32_t, float>(float val)
{
    return __float2int_rn(val);
}
// __half -> int32_t via the __half2int_rn intrinsic.
template <>
DS_D_INLINE int32_t to<int32_t, __half>(__half val)
{
    return __half2int_rn(val);
}
// No direct support for integer casts at the C++ level and I don't feel they're so important
// to demand an PTX at this time

#ifdef BF16_AVAILABLE
// __nv_bfloat16 -> int32_t via the __bfloat162int_rn intrinsic.
template <>
DS_D_INLINE int32_t to<int32_t, __nv_bfloat16>(__nv_bfloat16 val)
{
    return __bfloat162int_rn(val);
}
#endif

/*********************  To INT16_T Conversions *********************/
// double -> int16_t: converted to int with __double2int_rn, then narrowed to
// int16_t. No range clamp is applied in this function.
template <>
DS_D_INLINE int16_t to<int16_t, double>(double val)
{
    return static_cast<int16_t>(__double2int_rn(val));
}
// float -> int16_t: converted to int with __float2int_rn, then narrowed to
// int16_t. No range clamp is applied in this function.
template <>
DS_D_INLINE int16_t to<int16_t, float>(float val)
{
    return static_cast<int16_t>(__float2int_rn(val));
}
// __half -> int16_t: converted to int with __half2int_rn, then narrowed to
// int16_t. No range clamp is applied in this function.
template <>
DS_D_INLINE int16_t to<int16_t, __half>(__half val)
{
    return static_cast<int16_t>(__half2int_rn(val));
}
// No direct support for integer casts at the C++ level and I don't feel they're so important
// to demand an PTX at this time

#ifdef BF16_AVAILABLE
// __nv_bfloat16 -> int16_t: converted to int with __bfloat162int_rn, then
// narrowed to int16_t. No range clamp is applied in this function.
template <>
DS_D_INLINE int16_t to<int16_t, __nv_bfloat16>(__nv_bfloat16 val)
{
    return static_cast<int16_t>(__bfloat162int_rn(val));
}
#endif

/*********************  To INT8_T Conversions *********************/
// double -> int8_t: converted to int with __double2int_rn, then narrowed to
// int8_t. No range clamp is applied in this function.
template <>
DS_D_INLINE int8_t to<int8_t, double>(double val)
{
    return static_cast<int8_t>(__double2int_rn(val));
}
// float -> int8_t: converted to int with __float2int_rn, then narrowed to
// int8_t. No range clamp is applied in this function.
template <>
DS_D_INLINE int8_t to<int8_t, float>(float val)
{
    return static_cast<int8_t>(__float2int_rn(val));
}
// __half -> int8_t: converted to int with __half2int_rn, then narrowed to
// int8_t. No range clamp is applied in this function.
template <>
DS_D_INLINE int8_t to<int8_t, __half>(__half val)
{
    return static_cast<int8_t>(__half2int_rn(val));
}
// No direct support for integer casts at the C++ level and I don't feel they're so important
// to demand an PTX at this time

#ifdef BF16_AVAILABLE
// __nv_bfloat16 -> int8_t: converted to int with __bfloat162int_rn, then
// narrowed to int8_t. No range clamp is applied in this function.
template <>
DS_D_INLINE int8_t to<int8_t, __nv_bfloat16>(__nv_bfloat16 val)
{
    return static_cast<int8_t>(__bfloat162int_rn(val));
}
#endif

/*********************  To UINT64_T Conversions *********************/
// double -> uint64_t via the __double2ull_rn intrinsic.
template <>
DS_D_INLINE uint64_t to<uint64_t, double>(double val)
{
    return static_cast<uint64_t>(__double2ull_rn(val));
}
// float -> uint64_t via the __float2ull_rn intrinsic.
template <>
DS_D_INLINE uint64_t to<uint64_t, float>(float val)
{
    return static_cast<uint64_t>(__float2ull_rn(val));
}
// __half -> uint64_t via the __half2ull_rn intrinsic.
template <>
DS_D_INLINE uint64_t to<uint64_t, __half>(__half val)
{
    return static_cast<uint64_t>(__half2ull_rn(val));
}
// No direct support for integer casts at the C++ level and I don't feel they're so important
// to demand an PTX at this time

#ifdef BF16_AVAILABLE
// __nv_bfloat16 -> uint64_t via the __bfloat162ull_rn intrinsic.
template <>
DS_D_INLINE uint64_t to<uint64_t, __nv_bfloat16>(__nv_bfloat16 val)
{
    return static_cast<uint64_t>(__bfloat162ull_rn(val));
}
#endif

/*********************  To UINT32_T Conversions *********************/
// double -> uint32_t via the __double2uint_rn intrinsic.
template <>
DS_D_INLINE uint32_t to<uint32_t, double>(double val)
{
    return __double2uint_rn(val);
}
// float -> uint32_t via the __float2uint_rn intrinsic.
template <>
DS_D_INLINE uint32_t to<uint32_t, float>(float val)
{
    return __float2uint_rn(val);
}
// __half -> uint32_t via the __half2uint_rn intrinsic.
template <>
DS_D_INLINE uint32_t to<uint32_t, __half>(__half val)
{
    return __half2uint_rn(val);
}
// No direct support for integer casts at the C++ level and I don't feel they're so important
// to demand an PTX at this time

#ifdef BF16_AVAILABLE
// __nv_bfloat16 -> uint32_t via the __bfloat162uint_rn intrinsic.
template <>
DS_D_INLINE uint32_t to<uint32_t, __nv_bfloat16>(__nv_bfloat16 val)
{
    return __bfloat162uint_rn(val);
}
#endif

/*********************  To UINT16_T Conversions *********************/
// double -> uint16_t: converted to unsigned int with __double2uint_rn, then
// narrowed to uint16_t. No range clamp is applied in this function.
template <>
DS_D_INLINE uint16_t to<uint16_t, double>(double val)
{
    return static_cast<uint16_t>(__double2uint_rn(val));
}
// float -> uint16_t: converted to unsigned int with __float2uint_rn, then
// narrowed to uint16_t. No range clamp is applied in this function.
template <>
DS_D_INLINE uint16_t to<uint16_t, float>(float val)
{
    return static_cast<uint16_t>(__float2uint_rn(val));
}
// __half -> uint16_t: converted to unsigned int with __half2uint_rn, then
// narrowed to uint16_t. No range clamp is applied in this function.
template <>
DS_D_INLINE uint16_t to<uint16_t, __half>(__half val)
{
    return static_cast<uint16_t>(__half2uint_rn(val));
}
// No direct support for integer casts at the C++ level and I don't feel they're so important
// to demand an PTX at this time

#ifdef BF16_AVAILABLE
// __nv_bfloat16 -> uint16_t: converted to unsigned int with
// __bfloat162uint_rn, then narrowed to uint16_t. No range clamp is applied in
// this function.
template <>
DS_D_INLINE uint16_t to<uint16_t, __nv_bfloat16>(__nv_bfloat16 val)
{
    return static_cast<uint16_t>(__bfloat162uint_rn(val));
}
#endif

/*********************  To UINT8_T Conversions *********************/
// double -> uint8_t: converted to unsigned int with __double2uint_rn, then
// narrowed to uint8_t. No range clamp is applied in this function.
template <>
DS_D_INLINE uint8_t to<uint8_t, double>(double val)
{
    return static_cast<uint8_t>(__double2uint_rn(val));
}
// float -> uint8_t: converted to unsigned int with __float2uint_rn, then
// narrowed to uint8_t. No range clamp is applied in this function.
template <>
DS_D_INLINE uint8_t to<uint8_t, float>(float val)
{
    return static_cast<uint8_t>(__float2uint_rn(val));
}
// __half -> uint8_t: converted to unsigned int with __half2uint_rn, then
// narrowed to uint8_t. No range clamp is applied in this function.
template <>
DS_D_INLINE uint8_t to<uint8_t, __half>(__half val)
{
    return static_cast<uint8_t>(__half2uint_rn(val));
}
// No direct support for integer casts at the C++ level and I don't feel they're so important
// to demand an PTX at this time

#ifdef BF16_AVAILABLE
// __nv_bfloat16 -> uint8_t: converted to unsigned int with
// __bfloat162uint_rn, then narrowed to uint8_t. No range clamp is applied in
// this function.
template <>
DS_D_INLINE uint8_t to<uint8_t, __nv_bfloat16>(__nv_bfloat16 val)
{
    return static_cast<uint8_t>(__bfloat162uint_rn(val));
}
#endif

}  // namespace conversion