mps_ops.mm 1.59 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

// 65504 is the largest finite IEEE 754 half-precision (binary16) value.
// NOTE(review): presumably a clamp bound for half-precision data; it is not
// referenced in the visible part of this file, so confirm against its users.
#define HLF_MAX 65504
// NOTE(review): TH, NUM and NUM_BLOCK look like kernel launch/tiling
// parameters (threads per group, items per thread, items per block?), but
// none of them is referenced in the visible code; confirm against the Metal
// kernels before relying on that reading.
#define TH 1024
#define NUM 4
#define NUM_BLOCK 4096

// Returns the cached MPSGraph, allocating it on the first call.
// The graph lives in a function-local static, so later calls hand back the
// same object. No synchronization is performed around the lazy creation.
static inline MPSGraph* get_graph()
{
  static MPSGraph* shared_graph = nil;
  if(shared_graph != nil) {
    return shared_graph;
  }
  shared_graph = [[MPSGraph alloc] init];
  return shared_graph;
}

// Returns the system default Metal device, creating it on the first call and
// caching it in a function-local static for subsequent calls.
// Never returns nil: if no device can be obtained the failure is logged and
// the process is terminated with abort().
static inline id<MTLDevice> get_device()
{
  static id<MTLDevice> device = nil;
  if(!device) {
    device = MTLCreateSystemDefaultDevice();
    // Only a fresh creation attempt can yield nil, so the check lives here
    // rather than being repeated on every cached call.
    if(!device) {
      NSLog(@"Failed to get MPS device");
      abort();
    }
  }
  return device;
}

// Returns the Metal library loaded from "bitsandbytes.metallib", loading it
// on the first call and caching it in a function-local static afterwards.
// The path is relative, so it is resolved against the process's current
// working directory.
// Never returns nil: if loading fails, the failure and the NSError reported
// by Metal are logged and the process is terminated with abort().
static inline id<MTLLibrary> get_library()
{
  static id<MTLLibrary> library = nil;
  if(!library) {
    NSError *error = nil;
    NSURL *url = [NSURL fileURLWithPath:@"bitsandbytes.metallib"];
    library = [get_device() newLibraryWithURL:url error:&error];
    // Judge success by the return value, then surface the error so the
    // cause (missing file, incompatible library, ...) is visible in the log.
    if(!library) {
      NSLog(@"Failed to load bitsandbytes.metallib: %@", error);
      abort();
    }
  }
  return library;
}

/*MPSGraphTensor* dequantize_mps(MPSGraphTensor* code, MPSGraphTensor* A, int n)
{
  id out = [get_graph() dequantizeTensor:(MPSGraphTensor*)A scaleTensor:(MPSGraphTensor*)code zeroPoint:0.0 dataType:MPSDataTypeInt8 axis:0 name:@"out"];
  return out;
}*/


// Entry point for quantization on MPS, exported with C linkage.
//
// Currently a stub: it makes sure the "quantize" function can be looked up in
// the Metal library (caching it in a function-local static), then logs
// "Not implemented" and returns nil. The parameters graph, code, A and n are
// accepted but not used yet.
//
// Aborts the process if the Metal device or library cannot be obtained, or if
// the library has no function named "quantize".
extern "C" MPSGraphTensor* quantize_mps(MPSGraph* graph, MPSGraphTensor* code, MPSGraphTensor* A, int n)
{
  // get_library() obtains the Metal device itself when it has to load the
  // library, so no separate device handle is needed here.
  id<MTLLibrary> library = get_library();
  static id<MTLFunction> kernel = nil;
  if(!kernel) {
    kernel = [library newFunctionWithName:@"quantize"];
    if(!kernel) {
      // The library loaded fine at this point; what is missing is the function.
      NSLog(@"Failed to load quantize kernel from bitsandbytes.metallib");
      abort();
    }
  }
  NSLog(@"Not implemented");
  return nil;
}