mps_ops.mm 1.69 KB
Newer Older
1
2
3
4
5
6
7
#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

// Presumably the largest finite IEEE 754 half-precision (fp16) value — TODO confirm against the code that uses it.
#define HLF_MAX 65504
// NOTE(review): TH, NUM and NUM_BLOCK are not referenced in the visible part of this file.
// They look like kernel launch/tiling parameters (threads, items per thread, block size) —
// confirm their meaning against the Metal kernels before relying on them.
#define TH 1024
#define NUM 4
#define NUM_BLOCK 4096

8
9
10
11
12
13
/// Returns the cached MPSGraph, creating it on first use.
///
/// The graph is stored in a function-local static, so later calls hand back
/// the same object. No locking is performed around the lazy creation.
static inline MPSGraph* get_graph() {
    static MPSGraph* shared_graph = nil;
    if (shared_graph != nil) {
        return shared_graph;
    }
    shared_graph = [[MPSGraph alloc] init];
    return shared_graph;
}

16
17
18
19
20
21
22
23
24
25
26
/// Returns the cached default Metal device, creating it on first use.
///
/// The device comes from MTLCreateSystemDefaultDevice() and is stored in a
/// function-local static. If no device can be obtained, the failure is logged
/// and the process is aborted, so a returned value is never nil.
static inline id<MTLDevice> get_device() {
    static id<MTLDevice> device = nil;
    if (!device) {
        device = MTLCreateSystemDefaultDevice();
        if (!device) {
            // Nothing in this file can work without a device; fail fast.
            NSLog(@"Failed to get MPS device");
            abort();
        }
    }
    return device;
}

29
30
31
32
33
34
35
36
37
38
39
/// Returns the cached Metal library loaded from "bitsandbytes.metallib".
///
/// The library is loaded on first use through get_device() and stored in a
/// function-local static. If loading fails, the NSError reported by Metal is
/// logged and the process is aborted, so a returned value is never nil.
///
/// NOTE(review): the path is relative, so it is presumably resolved against
/// the process's current working directory — confirm this matches how the
/// metallib is deployed.
static inline id<MTLLibrary> get_library() {
    static id<MTLLibrary> library = nil;
    if (!library) {
        NSError* error = nil;
        NSURL* url = [NSURL fileURLWithPath:@"bitsandbytes.metallib"];
        library = [get_device() newLibraryWithURL:url error:&error];
        if (!library) {
            // Include the underlying error so the cause (missing file, bad
            // format, ...) is visible in the log before aborting.
            NSLog(@"Failed to load bitsandbytes.metallib: %@", error);
            abort();
        }
    }
    return library;
}

/*MPSGraphTensor* dequantize_mps(MPSGraphTensor* code, MPSGraphTensor* A, int n)
{
  id out = [get_graph() dequantizeTensor:(MPSGraphTensor*)A scaleTensor:(MPSGraphTensor*)code zeroPoint:0.0
dataType:MPSDataTypeInt8 axis:0 name:@"out"]; return out;
}*/

// MPSGraph function for quantize
49
50
51
52
53
54
55
56
57
58
/// Entry point for quantization on MPS. Currently a stub.
///
/// On first call it looks up the "quantize" function in the Metal library
/// returned by get_library() and caches it in a function-local static; if the
/// function cannot be found, the failure is logged and the process is aborted.
/// After that it only logs "Not implemented".
///
/// @param graph Unused by the current implementation.
/// @param code  Unused by the current implementation.
/// @param A     Unused by the current implementation.
/// @param n     Unused by the current implementation.
/// @return Always nil.
extern "C" MPSGraphTensor* quantize_mps(MPSGraph* graph, MPSGraphTensor* code, MPSGraphTensor* A, int n) {
    static id<MTLFunction> kernel = nil;
    if (!kernel) {
        kernel = [get_library() newFunctionWithName:@"quantize"];
        if (!kernel) {
            // The library itself loaded (get_library() aborts otherwise);
            // what is missing here is the kernel function.
            NSLog(@"Failed to find function 'quantize' in bitsandbytes.metallib");
            abort();
        }
    }

    NSLog(@"Not implemented");
    return nil;
}