Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5dba2575
Unverified
Commit
5dba2575
authored
Jan 03, 2025
by
wchen61
Committed by
GitHub
Jan 02, 2025
Browse files
Resolve race conditions in Marlin kernel (#11493)
Signed-off-by:
wchen61
<
wchen61@foxmail.com
>
parent
187e3299
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
19 deletions
+21
-19
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+21
-19
No files found.
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
5dba2575
...
@@ -834,6 +834,7 @@ __global__ void Marlin(
...
@@ -834,6 +834,7 @@ __global__ void Marlin(
int4
*
sh_g_idx
=
sh_b
+
(
stages
*
b_sh_stage
);
int4
*
sh_g_idx
=
sh_b
+
(
stages
*
b_sh_stage
);
int4
*
sh_zp
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
int4
*
sh_zp
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
int4
*
sh_s
=
sh_zp
+
(
stages
*
zp_sh_stage
);
int4
*
sh_s
=
sh_zp
+
(
stages
*
zp_sh_stage
);
int4
*
sh_red
=
sh_s
+
(
stages
*
s_sh_stage
);
// Register storage for double buffer of shared memory reads.
// Register storage for double buffer of shared memory reads.
FragA
frag_a
[
2
][
thread_m_blocks
];
FragA
frag_a
[
2
][
thread_m_blocks
];
...
@@ -932,11 +933,11 @@ __global__ void Marlin(
...
@@ -932,11 +933,11 @@ __global__ void Marlin(
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
// Only fetch scales if this tile starts a new group
if
(
pipe
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
if
(
s_sh_wr_pred
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s_stage
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
cp_async4
(
&
sh_s_stage
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
}
// Only fetch scales if this tile starts a new group
if
((
pipe
+
1
)
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
s_gl_rd
+=
s_gl_rd_delta
;
s_gl_rd
+=
s_gl_rd_delta
;
}
}
}
else
{
}
else
{
...
@@ -1038,9 +1039,7 @@ __global__ void Marlin(
...
@@ -1038,9 +1039,7 @@ __global__ void Marlin(
// No act-order case
// No act-order case
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_s_stage
=
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
sh_s
+
s_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
];
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
];
}
else
{
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
int
warp_id
=
threadIdx
.
x
/
32
;
...
@@ -1339,15 +1338,15 @@ __global__ void Marlin(
...
@@ -1339,15 +1338,15 @@ __global__ void Marlin(
int
red_sh_wr
=
int
red_sh_wr
=
red_sh_delta
*
j
+
(
red_sh_rd
-
red_sh_stride
*
i
);
red_sh_delta
*
j
+
(
red_sh_rd
-
red_sh_stride
*
i
);
if
(
i
<
red_off
)
{
if
(
i
<
red_off
)
{
float
*
c_rd
=
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
j
+
red_sh_rd
]);
&
sh_red
[
red_sh_delta
*
j
+
red_sh_rd
]);
float
*
c_wr
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_wr
]);
float
*
c_wr
=
reinterpret_cast
<
float
*>
(
&
sh
_red
[
red_sh_wr
]);
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
)
for
(
int
k
=
0
;
k
<
4
;
k
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
j
][
k
]
+=
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
j
][
k
]
+=
c_rd
[
k
]
+
c_wr
[
k
];
c_rd
[
k
]
+
c_wr
[
k
];
}
}
sh
[
red_sh_wr
]
=
sh
_red
[
red_sh_wr
]
=
reinterpret_cast
<
int4
*>
(
&
frag_c
)[
4
*
2
*
m_block
+
j
];
reinterpret_cast
<
int4
*>
(
&
frag_c
)[
4
*
2
*
m_block
+
j
];
}
}
}
}
...
@@ -1357,7 +1356,7 @@ __global__ void Marlin(
...
@@ -1357,7 +1356,7 @@ __global__ void Marlin(
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
4
*
2
;
i
++
)
{
for
(
int
i
=
0
;
i
<
4
*
2
;
i
++
)
{
float
*
c_rd
=
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
i
+
red_sh_rd
]);
reinterpret_cast
<
float
*>
(
&
sh
_red
[
red_sh_delta
*
i
+
red_sh_rd
]);
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
for
(
int
j
=
0
;
j
<
4
;
j
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
i
][
j
]
+=
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
i
][
j
]
+=
...
@@ -1397,7 +1396,7 @@ __global__ void Marlin(
...
@@ -1397,7 +1396,7 @@ __global__ void Marlin(
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
cp_async4_pred
(
cp_async4_pred
(
&
sh
[
c_sh_wr
+
c_sh_wr_delta
*
i
],
&
sh
_red
[
c_sh_wr
+
c_sh_wr_delta
*
i
],
&
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
&
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
)],
c_gl_wr_delta_i
*
(
i
%
2
)],
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
);
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
);
...
@@ -1410,7 +1409,7 @@ __global__ void Marlin(
...
@@ -1410,7 +1409,7 @@ __global__ void Marlin(
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
if
(
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
)
{
if
(
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
)
{
if
(
!
first
)
{
if
(
!
first
)
{
int4
c_red
=
sh
[
c_sh_wr
+
i
*
c_sh_wr_delta
];
int4
c_red
=
sh
_red
[
c_sh_wr
+
i
*
c_sh_wr_delta
];
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
float
*>
(
...
@@ -1461,10 +1460,10 @@ __global__ void Marlin(
...
@@ -1461,10 +1460,10 @@ __global__ void Marlin(
float
*
frag_c_ptr
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
float
*
frag_c_ptr
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
th_size
;
k
++
)
{
for
(
int
k
=
0
;
k
<
th_size
;
k
++
)
{
sh
[
threadIdx
.
x
]
=
sh
_red
[
threadIdx
.
x
]
=
C_tmp
[
c_cur_offset
+
active_threads
*
k
+
threadIdx
.
x
];
C_tmp
[
c_cur_offset
+
active_threads
*
k
+
threadIdx
.
x
];
float
*
sh_c_ptr
=
reinterpret_cast
<
float
*>
(
&
sh
[
threadIdx
.
x
]);
float
*
sh_c_ptr
=
reinterpret_cast
<
float
*>
(
&
sh
_red
[
threadIdx
.
x
]);
#pragma unroll
#pragma unroll
for
(
int
f
=
0
;
f
<
4
;
f
++
)
{
for
(
int
f
=
0
;
f
<
4
;
f
++
)
{
frag_c_ptr
[
k
*
4
+
f
]
+=
sh_c_ptr
[
f
];
frag_c_ptr
[
k
*
4
+
f
]
+=
sh_c_ptr
[
f
];
...
@@ -1515,7 +1514,7 @@ __global__ void Marlin(
...
@@ -1515,7 +1514,7 @@ __global__ void Marlin(
res
=
__hmul2
(
res
,
s
[
0
]);
res
=
__hmul2
(
res
,
s
[
0
]);
}
}
((
scalar_t2
*
)
sh
)[
idx
]
=
res
;
((
scalar_t2
*
)
sh
_red
)[
idx
]
=
res
;
};
};
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
...
@@ -1543,7 +1542,7 @@ __global__ void Marlin(
...
@@ -1543,7 +1542,7 @@ __global__ void Marlin(
i
<
div_ceil
(
16
*
thread_m_blocks
,
threads
/
(
2
*
thread_n_blocks
));
i
<
div_ceil
(
16
*
thread_m_blocks
,
threads
/
(
2
*
thread_n_blocks
));
i
++
)
{
i
++
)
{
if
(
c_gl_wr
<
c_gl_wr_end
)
{
if
(
c_gl_wr
<
c_gl_wr_end
)
{
C
[
c_gl_wr
]
=
sh
[
c_sh_rd
];
C
[
c_gl_wr
]
=
sh
_red
[
c_sh_rd
];
c_gl_wr
+=
c_gl_wr_delta
;
c_gl_wr
+=
c_gl_wr_delta
;
c_sh_rd
+=
c_sh_rd_delta
;
c_sh_rd
+=
c_sh_rd_delta
;
}
}
...
@@ -1865,9 +1864,12 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
...
@@ -1865,9 +1864,12 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
float
pipe_size
=
(
a_size
+
b_size
)
*
pipe_stages
;
float
pipe_size
=
(
a_size
+
b_size
)
*
pipe_stages
;
float
reduce_size
=
max
(
th_config
.
num_threads
*
32
*
4
,
(
tb_n
/
64
)
*
32
*
(
tb_max_m
/
16
)
*
4
*
2
*
4
*
2
);
TORCH_CHECK
(
max_shared_mem
/
2
>
scales_cache_size
);
// Sanity
TORCH_CHECK
(
max_shared_mem
/
2
>
scales_cache_size
);
// Sanity
return
pipe_size
<
0.95
f
*
(
max_shared_mem
-
scales_cache_size
);
return
pipe_size
+
reduce_size
<
0.95
f
*
(
max_shared_mem
-
scales_cache_size
);
}
}
bool
is_valid_config
(
thread_config_t
const
&
th_config
,
int
max_m_blocks
,
bool
is_valid_config
(
thread_config_t
const
&
th_config
,
int
max_m_blocks
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment