// mps_kernels.metal
#include <metal_stdlib>
using namespace metal;

// 65504 is the largest finite value representable in IEEE 754 half precision.
#define HLF_MAX 65504
// NOTE(review): presumably the number of threads per threadgroup the host is
// expected to dispatch with -- confirm against the host-side launch code.
#define TH 1024
// NOTE(review): presumably the number of elements handled per thread -- confirm.
#define NUM 4
// NOTE(review): presumably the number of elements per processing block.
// Numerically equal to TH * NUM (1024 * 4).
#define NUM_BLOCK 4096

/// Maps the scalar `x` to an index into the 256-entry table `code`.
///
/// A fixed seven-step bisection (step sizes 64, 32, 16, 8, 4, 2, 1) starting at
/// index 127 walks to a candidate index while remembering the last index visited
/// on either side of `x`. The result is either the candidate or the remembered
/// neighbour on the side of the candidate where `x` lies.
///
/// Template parameter:
///   STOCHASTIC  false: pick the neighbour only when `x` is past the midpoint
///               between the candidate's value and the neighbour's value.
///               true:  pick the neighbour when `rand` is >= the ratio
///               |neighbour - x| / |neighbour - candidate|.
///
/// Parameters:
///   rand  random threshold; only read when STOCHASTIC is true.
///   code  lookup table; indices 0..255 may be read.
///         NOTE(review): the bisection presumably expects the entries to be in
///         ascending order -- confirm with the code that builds the table.
///   x     value to quantize.
///
/// Returns the selected table index, converted to unsigned char.
template<bool STOCHASTIC>
static unsigned char quantize_scalar(
  float rand,
  device float* code,
  float x)
{
    int mid = 127;        // candidate index, refined by the search
    int hi_idx = 255;     // last index visited whose value was not below x
    int lo_idx = 0;       // last index visited whose value was below x

    float lo_val = -1.0f;
    float hi_val = 1.0f;

    float cur = code[mid];

    int step = 64;
    while (step > 0)
    {
        const bool go_up = x > cur;
        if (go_up)
        {
            lo_idx = mid;
            lo_val = cur;
            mid += step;
        }
        else
        {
            hi_idx = mid;
            hi_val = cur;
            mid -= step;
        }
        cur = code[mid];
        step >>= 1;
    }

    // A bound that was never tightened still carries its placeholder value;
    // replace it with the table entry at the corresponding end index.
    if (hi_idx == 255)
        hi_val = code[255];
    if (lo_idx == 0)
        lo_val = code[0];

    // Choose the neighbour lying on the same side of the candidate as x.
    const bool above = x > cur;
    const int neighbour_idx = above ? hi_idx : lo_idx;
    const float neighbour_val = above ? hi_val : lo_val;

    if (STOCHASTIC)
    {
        const float dist_to_neighbour = fabs(neighbour_val - x);
        const float dist_full = above ? (hi_val - cur) : (cur - lo_val);
        return (rand >= dist_to_neighbour / dist_full) ? neighbour_idx : mid;
    }

    const float midpoint = (neighbour_val + cur) * 0.5f;
    const bool pick_neighbour = above ? (x > midpoint) : (x < midpoint);
    return pick_neighbour ? neighbour_idx : mid;
}

/// Quantizes `n` floats from `A` into 8-bit table indices written to `out`.
///
/// Every element is mapped independently with quantize_scalar<false>, so
/// out[i] depends only on A[i] and the contents of `code`.
///
/// Work distribution: thread `id` handles elements id, id + grid_size,
/// id + 2 * grid_size, ... Each element below `n` is therefore written by
/// exactly one thread for any non-empty 1-D dispatch, and threads whose `id`
/// is >= n do nothing. Elements go straight from `A` to `out`: there are no
/// fixed-size per-thread staging arrays to overrun, and no barrier is needed
/// because no threadgroup memory is shared between threads.
///
/// Parameters:
///   code       buffer(0): lookup table passed through to quantize_scalar,
///              which reads indices 0..255.
///   A          buffer(1): input values; elements 0..n-1 are read.
///   out        buffer(2): output indices; elements 0..n-1 are written.
///   n          buffer(3): number of elements to quantize (0 is allowed).
///   id         x position of this thread in the dispatch grid.
///   grid_size  x extent of the dispatch grid; supplied by Metal, so the host
///              does not bind anything for it.
kernel void quantize(device float* code [[buffer(0)]],
                      device float* A [[buffer(1)]],
                      device uchar* out [[buffer(2)]],
                      constant uint& n [[buffer(3)]],
                      uint id [[thread_position_in_grid]],
                      uint grid_size [[threads_per_grid]]) {
  for (uint i = id; i < n; i += grid_size) {
    out[i] = quantize_scalar<false>(0.0f, code, A[i]);

    // Stop before advancing when the next index would be at or past n; this
    // also keeps `i += grid_size` from wrapping around for n close to UINT_MAX.
    if (n - i <= grid_size) {
      break;
    }
  }
}