ShaderLogisticRegression.comp 1.41 KB
Newer Older
mashun1's avatar
v1  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#version 450

// Specialization constant 0, set by the host. It is used only as the divisor in
// the 1/m scale applied to each sample's gradient (presumably the sample count).
// NOTE(review): the default of 0 makes 1/m infinite, so the host must specialize it.
layout (constant_id = 0) const float m = 0.0;

// One invocation per workgroup; gl_GlobalInvocationID.x selects the sample.
layout (local_size_x = 1) in;

// Inputs (only read by this shader):
//   xi, xj : the two feature values of each sample
//   y      : the label of each sample (presumably 0 or 1)
//   win    : current weights; only win[0] and win[1] are read
//   bin    : current bias; only bin[0] is read
// The readonly qualifier documents intent, lets the driver optimize, and turns an
// accidental write into a compile error.
layout(set = 0, binding = 0) readonly buffer bxi { float xi[]; };
layout(set = 0, binding = 1) readonly buffer bxj { float xj[]; };
layout(set = 0, binding = 2) readonly buffer by { float y[]; };
layout(set = 0, binding = 3) readonly buffer bwin { float win[]; };

// Outputs (only written by this shader), one entry per sample:
//   wouti, woutj : per-sample gradient contributions for the two weights
//   bout         : per-sample gradient contribution for the bias
//   lout         : per-sample loss
layout(set = 0, binding = 4) writeonly buffer bwouti { float wouti[]; };
layout(set = 0, binding = 5) writeonly buffer bwoutj { float woutj[]; };
layout(set = 0, binding = 6) readonly buffer bbin { float bin[]; };
layout(set = 0, binding = 7) writeonly buffer bbout { float bout[]; };
layout(set = 0, binding = 8) writeonly buffer blout { float lout[]; };

// Logistic function: maps any real z into [0, 1] via 1 / (1 + e^(-z)).
float sigmoid(float z) {
    float expNeg = exp(-z);
    float denom = 1.0 + expNeg;
    return 1.0 / denom;
}

// Forward pass for one sample: applies the sigmoid to the linear score w . x + b
// and returns the predicted probability yHat.
float inference(vec2 x, vec2 w, float b) {
    return sigmoid(dot(w, x) + b);
}

// Binary cross-entropy loss for a single sample:
//   L = -(y * log(yHat) + (1 - y) * log(1 - yHat))
// yHat is clamped away from exactly 0 and 1 before the logs are taken. Without
// the clamp, a saturated sigmoid (yHat == 0.0 or 1.0 in float) gives
// log(0) = -inf, and then 0 * -inf = NaN even when the prediction is correct.
float calculateLoss(float yHat, float y) {
    const float eps = 1.0e-7;
    float p = clamp(yHat, eps, 1.0 - eps);
    return -(y * log(p) + (1.0 - y) * log(1.0 - p));
}

// Each invocation handles one training sample, indexed by gl_GlobalInvocationID.x.
// local_size_x is 1, so the host presumably dispatches one workgroup per sample
// (confirm against the caller).
// For sample idx the shader computes yHat = sigmoid(w . x + b) and writes:
//   wouti[idx], woutj[idx] : x * (yHat - y) / m, the sample's weight-gradient term
//   bout[idx]              : (yHat - y) / m, the sample's bias-gradient term
//   lout[idx]              : the sample's cross-entropy loss
// No reduction is done here; the per-sample terms are presumably summed by the
// host or by a later pass.
void main() {
    uint idx = gl_GlobalInvocationID.x;

    // Skip invocations past the end of the data (e.g. an over-sized dispatch).
    // Accessing a runtime-sized array out of bounds is undefined. xi's length is
    // used as the sample count; this assumes xj, y and the outputs are at least as long.
    if (idx >= uint(xi.length())) {
        return;
    }

    vec2 wCurr = vec2(win[0], win[1]);
    float bCurr = bin[0];

    vec2 xCurr = vec2(xi[idx], xj[idx]);
    float yCurr = y[idx];

    float yHat = inference(xCurr, wCurr, bCurr);

    // For a sigmoid output with cross-entropy loss, dL/dz simplifies to (yHat - y).
    float dZ = yHat - yCurr;
    // The 1/m scale is shared by both gradient terms, so compute it once.
    // If m is left at its default of 0, this is infinite.
    float invM = 1.0 / m;
    vec2 dW = invM * xCurr * dZ;
    float dB = invM * dZ;
    wouti[idx] = dW.x;
    woutj[idx] = dW.y;
    bout[idx] = dB;

    lout[idx] = calculateLoss(yHat, yCurr);
}