Unverified Commit 0d140bd1 authored by Michael Yang's avatar Michael Yang Committed by GitHub
Browse files

fix: conv2d bias (#12834)

parent 93e45f0f
...@@ -10,7 +10,8 @@ type Conv2D struct { ...@@ -10,7 +10,8 @@ type Conv2D struct {
func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
t = m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1) t = m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1)
if m.Bias != nil { if m.Bias != nil {
t = t.Add(ctx, m.Bias) // Bias shape is (out_channels,) while t shape is (width, height, out_channels, batch)
t = t.Add(ctx, m.Bias.Reshape(ctx, 1, 1, -1))
} }
return t return t
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment